完美解决TensorFlow和Keras大数据量内存溢出的问题

内存溢出问题是参加kaggle比赛或者做大数据量实验的第一个拦路虎。

以前做的练手小项目导致新手产生一个惯性思维——读取训练集图片的时候把所有图读到内存中,然后分批训练。

其实这是有问题的,很容易导致OOM。现在内存一般16G,而训练集图片通常是上万张,而且RGB图,还很大,VGG16的图片一般是224x224x3,上万张图片,16G内存根本不够用。这时候又会想起——设置batch,但是那个batch的输入参数却又是图片,它只是把传进去的图片分批送到显卡,而我OOM的地方恰是那个“传进去”的图片,怎么办?

解决思路其实说来也简单,打破思维定式就好了,不是把所有图片读到内存中,而是只把所有图片的路径一次性读到内存中。

大致的解决思路为:

将上万张图片的路径一次性读到内存中,自己实现一个分批读取函数,在该函数中根据自己的内存情况设置读取图片,只把这一批图片读入内存中,然后交给模型,模型再对这一批图片进行分批训练,因为内存一般大于等于显存,所以内存的批次大小和显存的批次大小通常不相同。

下面代码分别介绍Tensorflow和Keras分批将数据读到内存中的关键函数。Tensorflow对初学者不太友好,所以我个人现阶段更习惯用它的高层API Keras来做相关项目,下面的TF实现是之前不会用Keras分批读时候参考的一些列资料,在模型训练上仍使用Keras,只有分批读取用了TF的API。

Tensorlow

在input.py里写get_batch函数。

def get_batch(X_train, y_train, img_w, img_h, color_type, batch_size, capacity):
  '''
  Args:
    X_train: train img path list
    y_train: train labels list
    img_w: image width
    img_h: image height
    batch_size: batch size
    capacity: the maximum elements in queue
  Returns:
    X_train_batch: 4D tensor [batch_size, width, height, chanel],\
            dtype=tf.float32
    y_train_batch: 1D tensor [batch_size], dtype=int32
  '''
  X_train = tf.cast(X_train, tf.string)

  y_train = tf.cast(y_train, tf.int32)

  # make an input queue
  input_queue = tf.train.slice_input_producer([X_train, y_train])

  y_train = input_queue[1]
  X_train_contents = tf.read_file(input_queue[0])
  X_train = tf.image.decode_jpeg(X_train_contents, channels=color_type)

  X_train = tf.image.resize_images(X_train, [img_h, img_w],
                   tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  X_train_batch, y_train_batch = tf.train.batch([X_train, y_train],
                         batch_size=batch_size,
                         num_threads=64,
                         capacity=capacity)
  y_train_batch = tf.one_hot(y_train_batch, 10)

  return X_train_batch, y_train_batch

在train.py文件中训练(下面不是纯TF代码,model.fit是Keras的拟合,用纯TF的替换就好了)。

X_train_batch, y_train_batch = inp.get_batch(X_train, y_train,
                       img_w, img_h, color_type,
                       train_batch_size, capacity)
X_valid_batch, y_valid_batch = inp.get_batch(X_valid, y_valid,
                       img_w, img_h, color_type,
                       valid_batch_size, capacity)
with tf.Session() as sess:

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  try:
    for step in np.arange(max_step):
      if coord.should_stop() :
        break
      X_train, y_train = sess.run([X_train_batch,
                       y_train_batch])
      X_valid, y_valid = sess.run([X_valid_batch,
                       y_valid_batch])

      ckpt_path = 'log/weights-{val_loss:.4f}.hdf5'
      ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                           monitor='val_loss',
                           verbose=1,
                           save_best_only=True,
                           mode='min')
      model.fit(X_train, y_train, batch_size=64,
             epochs=50, verbose=1,
             validation_data=(X_valid, y_valid),
             callbacks=[ckpt])

      del X_train, y_train, X_valid, y_valid

  except tf.errors.OutOfRangeError:
    print('done!')
  finally:
    coord.request_stop()
  coord.join(threads)
  sess.close()

Keras

keras文档中对fit、predict、evaluate这些函数都有一个generator,这个generator就是解决分批问题的。

关键函数:fit_generator

# 读取图片函数
def get_im_cv2(paths, img_rows, img_cols, color_type=1, normalize=True):
  '''
  参数:
    paths:要读取的图片路径列表
    img_rows:图片行
    img_cols:图片列
    color_type:图片颜色通道
  返回:
    imgs: 图片数组
  '''
  # Load as grayscale
  imgs = []
  for path in paths:
    if color_type == 1:
      img = cv2.imread(path, 0)
    elif color_type == 3:
      img = cv2.imread(path)
    # Reduce size
    resized = cv2.resize(img, (img_cols, img_rows))
    if normalize:
      resized = resized.astype('float32')
      resized /= 127.5
      resized -= 1. 

    imgs.append(resized)

  return np.array(imgs).reshape(len(paths), img_rows, img_cols, color_type)

获取批次函数,其实就是一个generator

def get_train_batch(X_train, y_train, batch_size, img_w, img_h, color_type, is_argumentation):
  '''
  参数:
    X_train:所有图片路径列表
    y_train: 所有图片对应的标签列表
    batch_size:批次
    img_w:图片宽
    img_h:图片高
    color_type:图片类型
    is_argumentation:是否需要数据增强
  返回:
    一个generator,x: 获取的批次图片 y: 获取的图片对应的标签
  '''
  while 1:
    for i in range(0, len(X_train), batch_size):
      x = get_im_cv2(X_train[i:i+batch_size], img_w, img_h, color_type)
      y = y_train[i:i+batch_size]
      if is_argumentation:
        # 数据增强
        x, y = img_augmentation(x, y)
      # 最重要的就是这个yield,它代表返回,返回以后循环还是会继续,然后再返回。就比如有一个机器一直在作累加运算,但是会把每次累加中间结果告诉你一样,直到把所有数加完
      yield({'input': x}, {'output': y})

训练函数

result = model.fit_generator(generator=get_train_batch(X_train, y_train, train_batch_size, img_w, img_h, color_type, True),
     steps_per_epoch=1351,
     epochs=50, verbose=1,
     validation_data=get_train_batch(X_valid, y_valid, valid_batch_size,img_w, img_h, color_type, False),
     validation_steps=52,
     callbacks=[ckpt, early_stop],
     max_queue_size=capacity,
     workers=1)

就是这么简单。但是当初从0到1的过程很难熬,每天都没有进展,没有头绪,急躁占据了思维的大部,熬过了这个阶段,就会一切顺利,不是运气,而是踩过的从0到1的每个脚印累积的灵感的爆发,从0到1的脚印越多,后面的路越顺利。

以上这篇完美解决TensorFlow和Keras大数据量内存溢出的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • windows安装TensorFlow和Keras遇到的问题及其解决方法

    安装TensorFlow在Windows上,真是让我心力交瘁,想死的心都有了,在Windows上做开发真的让人发狂. 首先说一下我的经历,本来也就是起初,网上说python3.7不支持TensorFlow环境,而且使用Anaconda最好,所以我将我之前Windows上所有的python环境卸载掉!!!,对没错,是所有,包括Anaconda环境,python环境,pycharm环境也卸载掉了.而且我丧心病狂的在电脑上找几乎所有关于python的字眼,全部删除掉,统统不留.只是为了铁了心在Wind

  • 解决Keras 与 Tensorflow 版本之间的兼容性问题

    在利用Keras进行实验的时候,后端为Tensorflow,出现了以下问题: 1. 服务器端激活Anaconda环境跑程序时,实验结果很差. 环境:tensorflow 1.4.0,keras 2.1.5 2. 服务器端未激活Anaconda环境跑程序时,实验结果回到正常值. 环境:tensorflow 1.7.0,keras 2.0.8 3. 自己PC端跑相同程序时,实验结果回到正常值. 环境:tensorflow 1.6.0,keras 2.1.5 怀疑实验结果的异常性是由于Keras和Te

  • 详解解决Python memory error的问题(四种解决方案)

    昨天在用用Pycharm读取一个200+M的CSV的过程中,竟然出现了Memory Error!简直让我怀疑自己买了个假电脑,毕竟是8G内存i7处理器,一度怀疑自己装了假的内存条....下面说一下几个解题步骤....一般就是用下面这些方法了,按顺序试试. 一.逐行读取 如果你用pd.read_csv来读文件,会一次性把数据都读到内存里来,导致内存爆掉,那么一个想法就是一行一行地读它,代码如下: data = [] with open(path, 'r',encoding='gbk',errors

  • Tensorflow与Keras自适应使用显存方式

    Tensorflow支持基于cuda内核与cudnn的GPU加速,Keras出现较晚,为Tensorflow的高层框架,由于Keras使用的方便性与很好的延展性,之后更是作为Tensorflow的官方指定第三方支持开源框架. 但两者在使用GPU时都有一个特点,就是默认为全占满模式.在训练的情况下,特别是分步训练时会导致显存溢出,导致程序崩溃. 可以使用自适应配置来调整显存的使用情况. 一.Tensorflow 1.指定显卡 代码中加入 import os os.environ["CUDA_VIS

  • 完美解决TensorFlow和Keras大数据量内存溢出的问题

    内存溢出问题是参加kaggle比赛或者做大数据量实验的第一个拦路虎. 以前做的练手小项目导致新手产生一个惯性思维--读取训练集图片的时候把所有图读到内存中,然后分批训练. 其实这是有问题的,很容易导致OOM.现在内存一般16G,而训练集图片通常是上万张,而且RGB图,还很大,VGG16的图片一般是224x224x3,上万张图片,16G内存根本不够用.这时候又会想起--设置batch,但是那个batch的输入参数却又是图片,它只是把传进去的图片分批送到显卡,而我OOM的地方恰是那个"传进去&quo

  • 详解Mysql数据库平滑扩容解决高并发和大数据量问题

    目录 1 停机方案 2 停写方案 3 平滑扩容之双写方案(中小型数据) 4 平滑扩容之2N方案大数据量问题解决 4.1 扩容问题 4.2 解决方案 4.3 双主架构思想 4.4 环境部署 5 数据库秒级平滑2N扩容实践 5.1 新增数据库VIP 5.2 应用服务增加动态数据源 5.3 解除原双主同步 5.4 安装MariaDB扩容服务器 5.5 增加KeepAlived服务实现高可用 5.6 清理数据并验证 1 停机方案 发布公告 停止服务 离线数据迁移(拆分,重新分配数据) 数据校验 更改配置

  • 解决Java导入excel大量数据出现内存溢出的问题

    问题:系统要求导入40万条excel数据,采用poi方式,服务器出现内存溢出情况. 解决方法:由于HSSFWorkbook workbook = new HSSFWorkbook(path)一次性将excel load到内存中导致内存不够. 故采用读取csv格式.由于csv的数据以x1,x2,x3形成,类似读取txt文档. private BufferedReader bReader; /** * 执行文件入口 */ public void execute() { try { if(!path.

  • 完美解决因数据库一次查询数据量过大导致的内存溢出问题

    刚开始接触项目的实习生,积累经验,欢迎交流 之前做项目,遇到过一次查询数据量过大而导致的内存溢出问题,找了很多办法一直未能实际解决问题, 今天又遇到了,经过前辈的指导,终于解决了问题!! 不过此方法只在DBug启动下有效 以上这篇完美解决因数据库一次查询数据量过大导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • 分布式难题ElasticSearch解决大数据量检索面试

    目录 引言 1.面试官: 我看你简历有写项目里有使用了ES,哪些场景用到了ES? 2.面试官: 那使用了ES后结果如何? 3.面试官: 关于ES的一些概念名字你了解多少?如索引,文档,倒排索引这些东西你是怎么理解的? 最重要的倒排索引 总结 关于ES的特性是使用场景概括: 引言 如果你的项目里有超过千万上亿级别的数据,且数据日增量较大需要高性能检索时,如订单数据,你该怎么办? 作为面试官,你需要找一个能解决这个问题的人!为应聘者,你该如何回答面试官这个问题? 你可以了解下使用搜索引擎框架,Ela

  • 针对Sqlserver大数据量插入速度慢或丢失数据的解决方法

    我的设备上每秒将2000条数据插入数据库,2个设备总共4000条,当在程序里面直接用insert语句插入时,两个设备同时插入大概总共能插入约2800条左右,数据丢失约1200条左右,测试了很多方法,整理出了两种效果比较明显的解决办法: 方法一:使用Sql Server函数: 1.将数据组合成字串,使用函数将数据插入内存表,后将内存表数据复制到要插入的表. 2.组合成的字符换格式:'111|222|333|456,7894,7458|0|1|2014-01-01 12:15:16;1111|222

  • 解决SpringBoot框架因post数据量过大没反应问题(踩坑)

    此处网上最多的做法是需要修改tomcat的参数配置大致如下: <Connector port="8080" protocol="HTTP/1.1" connectionTimeout="2000" redirectPort="8443" URIEncoding="UTF-8" maxThreads="3000" compression="on" compress

  • 解决Mybatis 大数据量的批量insert问题

    前言 通过Mybatis做7000+数据量的批量插入的时候报错了,error log如下: , ('G61010352', '610103199208291214', '学生52', 'G61010350', '610103199109920192', '学生50', '07', '01', '0104', ' ', , ' ', ' ', current_timestamp, current_timestamp ) 被中止,呼叫 getNextException 以取得原因. at org.p

  • php 大数据量及海量数据处理算法总结

    下面的方法是我对海量数据的处理方法进行了一个一般性的总结,当然这些方法可能并不能完全覆盖所有的问题,但是这样的一些方法也基本可以处理绝大多数遇到的问题.下面的一些问题基本直接来源于公司的面试笔试题目,方法不一定最优,如果你有更好的处理方法,欢迎与我讨论. 1.Bloom filter 适用范围:可以用来实现数据字典,进行数据的判重,或者集合求交集 基本原理及要点: 对于原理来说很简单,位数组+k个独立hash函数.将hash函数对应的值的位数组置1,查找时如果发现所有hash函数对应位都是1说明

  • 大数据量高并发的数据库优化详解

    如果不能设计一个合理的数据库模型,不仅会增加客户端和服务器段程序的编程和维护的难度,而且将会影响系统实际运行的性能.所以,在一个系统开始实施之前,完备的数据库模型的设计是必须的. 一.数据库结构的设计 在一个系统分析.设计阶段,因为数据量较小,负荷较低.我们往往只注意到功能的实现,而很难注意到性能的薄弱之处,等到系统投入实际运行一段时间后,才发现系统的性能在降低,这时再来考虑提高系统性能则要花费更多的人力物力,而整个系统也不可避免的形成了一个打补丁工程. 所以在考虑整个系统的流程的时候,我们必须

随机推荐