tensorflow入门:TFRecordDataset变长数据的batch读取详解

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行;但如果每条数据的长度不一样(常见于语音、视频、NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法:

1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord;这个方法的问题在于:若是有大量数据的长度都远远小于最大长度,则会造成存储空间的大量浪费。

2.使用dataset中的padded_batch方法来进行,参数padded_shapes #指明每条记录中各成员要pad成的形状,成员若是scalar,则用[],若是list,则用[mx_length],若是array,则用[d1,...,dn],假如各成员的顺序是scalar数据、list数据、array数据,则padded_shapes=([], [mx_length], [d1,...,dn]);该方法的函数说明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默认使用各类型数据的默认值,一般使用时可忽略该项
)

使用mnist数据来举例说明,首先在把mnist写入tfrecord之前,把mnist数据进行更改,以使得每个mnist图像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #随机丢掉几个数据点,以使长度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()

import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)

# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加载批量数据,在解析数据时用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函数使用,如下:

import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature读入的是一个sparse_tensor,用该函数进行转换
 label = tf.reshape(feats['label'],[2,5]) #把label变成[2,5],以说明array数据如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape

def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed

epochs = 16
batch_size = 50
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一个scalar,不输入数字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上这篇tensorflow入门:TFRecordDataset变长数据的batch读取详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Tensorflow中使用tfrecord方式读取数据的方法

    前言 本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用多种方式来读取自己的数据.如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来.如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了.此篇博客就主要讲一下Tensorflow官方推荐的一种较为高效的数据读取方式--tfre

  • tensorflow中tf.slice和tf.gather切片函数的使用

    tf.slice(input_, begin, size, name=None):按照指定的下标范围抽取连续区域的子集 tf.gather(params, indices, validate_indices=None, name=None):按照指定的下标集合从axis=0中抽取子集,适合抽取不连续区域的子集 输出: input = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]] tf.slice(

  • tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用

    1.创建tfrecord tfrecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList.tf.train.Int64List.tf.train.FloatList写入tf.train.Feature,如下所示: tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()])) #feature一般是多维数组,要先转为list tf.t

  • tensorflow入门:TFRecordDataset变长数据的batch读取详解

    在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行:但如果每条数据的长度不一样(常见于语音.视频.NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法: 1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord:这个方法的问题在于:若是有大量

  • 深入SQL Server中定长char(n)与变长varchar(n)的区别详解

    char(n)是定长格式,格式为char(n)的字段固定占用n个字符宽度,如果实际存放的数据长度超过n将被截取多出部分,如果长度小于n就用空字符填充. varchar(n)是变长格式,这种格式的字段根据实际数据长度分配空间,不浪费对于的空间,但是搜索数据的速度会麻烦一点. 一般地说,只要一个表有一个字段定义为varchar(n)类型,那么其余用char(n)定义的字段实际上也是varchar(n)类型. 如果你的长度本身不长,比如就3-10个字符,那么使用char(n)格式效率比较高,搜索速度快

  • TensorFlow人工智能学习按索引取数据及维度变换详解

    目录 一.按索引取数据 ①tf.gather() ②tf.gather_nd ③tf.boolean_mask 二.维度变换 ①tf.reshape() ②tf.transpose() ③tf.expand_dims() ④tf.squeeze() 一.按索引取数据 ①tf.gather() 输入参数:数据.维度.索引 例:设数据是[4,35,8],4个班级,每个班级35个学生,每个学生8门课成绩. 则下面In [49]的意思是,全部四个班级,每个班级取编号为2,3,7,9,16的学生,每个学生

  • Ajax 入门之 GET 与 POST 的不同处详解

    在之前的随笔中,本着怀旧的态度总结了一篇 兼容不同浏览器 建立XHR对象的方法: 在建立好XHR对象之后,客户端需要做的就是,将数据以某种方式传递到服务器,以获得相应的响应,在这里,  Ajax技术总结的第二季,我将重点阐述 提交数据的两种方式. 在这之前需要了解一下我们的HTTP传输协议: HTTP 的工作方式是客户机与服务器之间的请求-应答协议. 举例:客户端(浏览器)向服务器提交 HTTP 请求:服务器向客户端返回响应.响应包含关于请求的状态信息以及可能被请求的内容.而想要基于HTTP协议

  • Oracle数据操作和控制语言详解

    正在看的ORACLE教程是:Oracle数据操作和控制语言详解.SQL语言共分为四大类:数据查询语言DQL,数据操纵语言DML, 数据定义语言DDL,数据控制语言DCL.其中用于定义数据的结构,比如 创建.修改或者删除数据库:DCL用于定义数据库用户的权限:在这篇文章中我将详细讲述这两种语言在Oracle中的使用方法. DML语言 DML是SQL的一个子集,主要用于修改数据,下表列出了ORACLE支持的DML语句. 插入数据 INSERT语句常常用于向表中插入行,行中可以有特殊数据字段,或者可以

  • 对Tensorflow中权值和feature map的可视化详解

    前言 Tensorflow中可以使用tensorboard这个强大的工具对计算图.loss.网络参数等进行可视化.本文并不涉及对tensorboard使用的介绍,而是旨在说明如何通过代码对网络权值和feature map做更灵活的处理.显示和存储.本文的相关代码主要参考了github上的一个小项目,但是对其进行了改进. 原项目地址为(https://github.com/grishasergei/conviz). 本文将从以下两个方面进行介绍: 卷积知识补充 网络权值和feature map的可

  • TensorFlow人工智能学习张量及高阶操作示例详解

    目录 一.张量裁剪 1.tf.maximum/minimum/clip_by_value() 2.tf.clip_by_norm() 二.张量排序 1.tf.sort/argsort() 2.tf.math.topk() 三.TensorFlow高阶操作 1.tf.where() 2.tf.scatter_nd() 3.tf.meshgrid() 一.张量裁剪 1.tf.maximum/minimum/clip_by_value() 该方法按数值裁剪,传入tensor和阈值,maximum是把数

  • C++入门教程之内联函数与extern "C"详解

    目录 一.   内联函数 1.概念及分析 2.特性 3.宏 二.   extern “C” 1.C++程序 2.C程序 总结 一.   内联函数 1.概念及分析 以inline修饰的函数叫做内联函数,编译时C++编译器会在调用内联函数的地方展开,没有函数调用建立栈帧的开销,内联函数提升程序运行的效率. int Add(int a, int b) { int c = a + b; return c; } int main() { int ret= Add(1, 2); return 0; } 在我

  • mysql数据存储过程参数实例详解

    MySQL 存储过程参数有三种类型:in.out.inout.它们各有什么作用和特点呢? 一.MySQL 存储过程参数(in) MySQL 存储过程 "in" 参数:跟 C 语言的函数参数的值传递类似, MySQL 存储过程内部可能会修改此参数,但对 in 类型参数的修改,对调用者(caller)来说是不可见的(not visible). drop procedure if exists pr_param_in; create procedure pr_param_in ( in id

  • IOS 数据库升级数据迁移的实例详解

    IOS 数据库升级数据迁移的实例详解 概要: 很久以前就遇到过数据库版本升级的引用场景,当时的做法是简单的删除旧的数据库文件,重建数据库和表结构,这种暴力升级的方式会导致旧的数据的丢失,现在看来这并不不是一个优雅的解决方案,现在一个新的项目中又使用到了数据库,我不得不重新考虑这个问题,我希望用一种比较优雅的方式去解决这个问题,以后我们还会遇到类似的场景,我们都想做的更好不是吗? 理想的情况是:数据库升级,表结构.主键和约束有变化,新的表结构建立之后会自动的从旧的表检索数据,相同的字段进行映射迁移

随机推荐