Tensorflow训练模型越来越慢的2种解决方案

1 解决方案

【方案一】

载入模型结构放在全局,即tensorflow会话外层。

'''载入模型结构:最关键的一步'''
saver = tf.train.Saver()
'''建立会话'''
with tf.Session() as sess:
 for i in range(STEPS):
 '''开始训练'''
 _, loss_1, acc, summary = sess.run([train_op_1, train_loss, train_acc, summary_op], feed_dict=feed_dict)
 '''保存模型'''
 saver.save(sess, save_path="./model/path", i)

【方案二】

在方案一的基础上,将模型结构放在图会话的外部。

'''预测值'''
train_logits= network_model.inference(inputs, keep_prob)
'''损失值'''
train_loss = network_model.losses(train_logits)
'''优化'''
train_op = network_model.train(train_loss, learning_rate)
'''准确率'''
train_acc = network_model.evaluation(train_logits, labels)
'''模型输入'''
feed_dict = {inputs: x_batch, labels: y_batch, keep_prob: 0.5}
'''载入模型结构'''
saver = tf.train.Saver()
'''建立会话'''
with tf.Session() as sess:
 for i in range(STEPS):
 '''开始训练'''
 _, loss_1, acc, summary = sess.run([train_op_1, train_loss, train_acc, summary_op], feed_dict=feed_dict)
 '''保存模型'''
 saver.save(sess, save_path="./model/path", i) 

2 时间测试

通过不同方法测试训练程序,得到不同的训练时间,每执行一次训练都重新载入图结构,会使每一步的训练时间逐次增加,如果训练步数越大,后面训练速度越来越慢,最终可导致图爆炸,而终止训练。

【时间累加】

2019-05-15 10:55:29.009205: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
step: 0, time cost: 1.8800880908966064
step: 1, time cost: 1.592250108718872
step: 2, time cost: 1.553826093673706
step: 3, time cost: 1.5687050819396973
step: 4, time cost: 1.5777575969696045
step: 5, time cost: 1.5908267498016357
step: 6, time cost: 1.5989274978637695
step: 7, time cost: 1.6078357696533203
step: 8, time cost: 1.6087186336517334
step: 9, time cost: 1.6123006343841553
step: 10, time cost: 1.6320762634277344
step: 11, time cost: 1.6317598819732666
step: 12, time cost: 1.6570467948913574
step: 13, time cost: 1.6584930419921875
step: 14, time cost: 1.6765813827514648
step: 15, time cost: 1.6751370429992676
step: 16, time cost: 1.7304580211639404
step: 17, time cost: 1.7583982944488525

【时间均衡】

2019-05-15 13:03:49.394354: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 7048 MB memory) -> physical GPU (device: 1, name: Tesla P4, pci bus id: 0000:00:0d.0, compute capability: 6.1)
step: 0, time cost: 1.9781079292297363
loss1:6.78, loss2:5.47, loss3:5.27, loss4:7.31, loss5:5.44, loss6:6.87, loss7: 6.84
Total loss: 43.98, accuracy: 0.04, steps: 0, time cost: 1.9781079292297363
step: 1, time cost: 0.09688425064086914
step: 2, time cost: 0.09693264961242676
step: 3, time cost: 0.09671926498413086
step: 4, time cost: 0.09688210487365723
step: 5, time cost: 0.09646058082580566
step: 6, time cost: 0.09669041633605957
step: 7, time cost: 0.09666872024536133
step: 8, time cost: 0.09651994705200195
step: 9, time cost: 0.09705543518066406
step: 10, time cost: 0.09690332412719727

3 原因分析

(1) Tensorflow使用图结构构建系统,图结构中有节点(node)和边(operation),每次进行计算时会向图中添加边和节点进行计算或者读取已存在的图结构;

(2) 使用图结构也是一把双刃之剑,可以加快计算和提高设计效率,但是,程序设计不合理会导向负面,使训练越来约慢;

(3) 训练越来越慢是因为运行一次sess.run,向图中添加一次节点或者重新载入一次图结构,导致图中节点和边越来越多,计算参数也成倍增长;

(4) tf.train.Saver()就是载入图结构的类,因此设计训练程序时,若每执行一次跟新就使用该类载入图结构,自然会增加参数数量,必然导致训练变慢;

(5) 因此,将载入图结构的类放在全局,即只载入一次图结构,其他时间只训练图结构中的参数,可保持原有的训练速度;

4 总结

(1) 设计训练网络,只载入一次图结构即可;

(2) tf.train.Saver()就是载入图结构的类,将该类的实例化放在全局,即会话外部,解决训练越来越慢。

以上这篇Tensorflow训练模型越来越慢的2种解决方案就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFlow实现保存训练模型为pd文件并恢复

    TensorFlow保存模型代码 import tensorflow as tf from tensorflow.python.framework import graph_util var1 = tf.Variable(1.0, dtype=tf.float32, name='v1') var2 = tf.Variable(2.0, dtype=tf.float32, name='v2') var3 = tf.Variable(2.0, dtype=tf.float32, name='v3')

  • 从训练好的tensorflow模型中打印训练变量实例

    从tensorflow 训练后保存的模型中打印训变量:使用tf.train.NewCheckpointReader() import tensorflow as tf reader = tf.train.NewCheckpointReader('path/alexnet/model-330000') dic = reader.get_variable_to_shape_map() print dic 打印变量 w = reader.get_tensor("fc1/W") print t

  • tensorflow模型继续训练 fineturn实例

    解决tensoflow如何在已训练模型上继续训练fineturn的问题. 训练代码 任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解. # -*- coding: utf-8 -*-) import tensorflow as tf # 声明占位变量x.y x = tf.placeholder("float", shape=[None, 1]) y = tf.placeholder("float", [None,

  • tensorflow获取预训练模型某层参数并赋值到当前网络指定层方式

    已经有了一个预训练的模型,我需要从其中取出某一层,把该层的weights和biases赋值到新的网络结构中,可以使用tensorflow中的pywrap_tensorflow(用来读取预训练模型的参数值)结合Session.assign()进行操作. 这种需求即预训练模型可能为单分支网络,当前网络为多分支,我需要把单分支A复用到到多个分支去(B,C,D). 先导入对应的工具包 from tensorflow.python import pywrap_tensorflow 接下来的操作在一个tf.

  • Tensorflow实现在训练好的模型上进行测试

    Tensorflow可以使用训练好的模型对新的数据进行测试,有两种方法:第一种方法是调用模型和训练在同一个py文件中,中情况比较简单:第二种是训练过程和调用模型过程分别在两个py文件中.本文将讲解第二种方法. 模型的保存 tensorflow提供可保存训练模型的接口,使用起来也不是很难,直接上代码讲解: #网络结构 w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1)) b1 = tf.Variable(tf

  • tensorflow如何继续训练之前保存的模型实例

    一:需重定义神经网络继续训练的方法 1.训练代码 import numpy as np import tensorflow as tf x_data=np.random.rand(100).astype(np.float32) y_data=x_data*0.1+0.3 weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w") biases=tf.Variable(tf.zeros([1]),name="b&qu

  • Python通过TensorFLow进行线性模型训练原理与实现方法详解

    本文实例讲述了Python通过TensorFLow进行线性模型训练原理与实现方法.分享给大家供大家参考,具体如下: 1.相关概念 例如要从一个线性分布的途中抽象出其y=kx+b的分布规律 特征是输入变量,即简单线性回归中的 x 变量.简单的机器学习项目可能会使用单个特征,而比较复杂的机器学习项目可能会使用数百万个特征. 标签是我们要预测的事物,即简单线性回归中的 y 变量. 样本是指具体的数据实例.有标签样本是指具有{特征,标签}的数据,用于训练模型,总结规律.无标签样本只具有特征的数据x,通过

  • Tensorflow训练模型越来越慢的2种解决方案

    1 解决方案 [方案一] 载入模型结构放在全局,即tensorflow会话外层. '''载入模型结构:最关键的一步''' saver = tf.train.Saver() '''建立会话''' with tf.Session() as sess: for i in range(STEPS): '''开始训练''' _, loss_1, acc, summary = sess.run([train_op_1, train_loss, train_acc, summary_op], feed_dic

  • 关于MySQL与Golan分布式事务经典的七种解决方案

    目录 1.基础理论 1.1 事务 1.2 分布式事务 2.分布式事务的解决方案 2.1 两阶段提交/XA 2.2 SAGA 2.3 TCC 2.4 本地消息表 2.5 事务消息 2.6 最大努力通知 2.7 AT事务模式 3.异常处理 3.1 异常情况 3.2 子事务屏障 3.3 子事务屏障原理 3.4 子事务屏障小结 4.分布式事务实践 4.1 一个SAGA事务 4.2 处理网络异常 4.3 处理回滚 5.总结 前言: 随着业务的快速发展.业务复杂度越来越高,几乎每个公司的系统都会从单体走向分

  • SQL Server 2008 R2占用cpu、内存越来越大的两种解决方法

    SQL Server 2008 R2运行越久,占用内存会越来越大. 第一种: 有了上边的分析结果,解决方法就简单了,定期重启下SQL Server 2008 R2数据库服务即可,使用任务计划定期执行下边批处理: net stop sqlserveragent net stop mssqlserver net start mssqlserver net start sqlserveragent 第二种: 进入Sql server 企业管理器(管理数据库和表的,这个都不知道就不用往下看了),在数据库

  • Java中实现文件上传下载的三种解决方案(推荐)

    java文件上传与文件下载是程序开发中比较常见的功能,下面通过本文给大家介绍Java中实现文件上传下载的三种解决方案,具体详情如下所示: 第一点:Java代码实现文件上传 FormFile file=manform.getFile(); String newfileName = null; String newpathname=null; String fileAddre="/numUp"; try { InputStream stream = file.getInputStream(

  • Json返回时间的格式中出现乱码问题的两种解决方案

    前言:这段时间一直没有写博客,首先是我正在实现权限系列的绝色和操作的实现,因为这些东西在前面我们都已经说过了,所以我们就不重复的说这些了,那么我们知道,在我们使用Json返回数据的时候时间的格式一般都会变了,变成我们不认识的一些字符,那么当我们遇到这些问题的时候我们该怎么解决呢,今天我就来小说一下这个的解决方法. .发现问题 (1).正如我们在前言里面所说,我们在编写Json解析时间的时候会返回一些莫名其妙的东西,那么我们是如何解决这个问题的呢?我现在有两种方法可以解决这个问题,下面我们首先来说

  • 详解js跨域原理以及2种解决方案

    1.什么是跨域 我们经常会在页面上使用ajax请求访问其他服务器的数据,此时,客户端会出现跨域问题. 跨域问题是由于javascript语言安全限制中的同源策略造成的. 简单来说,同源策略是指一段脚本只能读取来自同一来源的窗口和文档的属性,这里的同一来源指的是主机名.协议和端口号的组合. 例如: 2.实现原理 在HTML DOM中,Script标签是可以跨域访问服务器上的数据的.因此,可以指定script的src属性为跨域的url,从而实现跨域访问. 例如: 这种访问方式是不行的.但是如下方式,

  • 详解解决Python memory error的问题(四种解决方案)

    昨天在用用Pycharm读取一个200+M的CSV的过程中,竟然出现了Memory Error!简直让我怀疑自己买了个假电脑,毕竟是8G内存i7处理器,一度怀疑自己装了假的内存条....下面说一下几个解题步骤....一般就是用下面这些方法了,按顺序试试. 一.逐行读取 如果你用pd.read_csv来读文件,会一次性把数据都读到内存里来,导致内存爆掉,那么一个想法就是一行一行地读它,代码如下: data = [] with open(path, 'r',encoding='gbk',errors

  • 详解MySQL双活同步复制四种解决方案

    对于数据实时同步,其核心是需要基于日志来实现,是可以实现准实时的数据同步,基于日志实现不会要求数据库本身在设计和实现中带来任何额外的约束. 基于MySQL原生复制主主同步方案  这是常见的方案,一般来说,中小型规模的时候,采用这种架构是最省事的. 两个节点可以采用简单的双主模式,并且使用专线连接,在master_A节点发生故障后,应用连接快速切换到master_B节点,反之也亦然.有几个需要注意的地方,脑裂的情况,两个节点写入相同数据而引发冲突,同时把两个节点的auto_increment_in

  • Docker容器内不能联网的6种解决方案

    Docker容器内不能联网的6种解决方案 注:下面的方法是在容器内能ping通公网IP的解决方案,如果连公网IP都ping不通,那主机可能也上不了网(尝试ping 8.8.8.8) 1.使用–net:host选项 sudo docker run --net:host --name ubuntu_bash -i -t ubuntu:latest /bin/bash 2.使用–dns选项 sudo docker run --dns 8.8.8.8 --dns 8.8.4.4 --name ubunt

  • 在微信小程序中渲染HTML内容3种解决方案及分析与问题解决

    大部分Web应用的富文本内容都是以HTML字符串的形式存储的,通过HTML文档去展示HTML内容自然没有问题.但是,在微信小程序(下文简称为「小程序」)中,应当如何渲染这部分内容呢? 在微信小程序中渲染HTML内容的3种解决方案 wxParse 小程序刚上线那会儿,是无法直接渲染HTML内容的,于是就诞生了一个叫做「wxParse」的库.它的原理就是把HTML代码解析成树结构的数据,再通过小程序的模板把该数据渲染出来. rich-text 后来,小程序增加了「rich-text」组件用于展示富文

随机推荐