使用TFRecord存取多个数据案例

TensorFlow提供了一种统一的格式来存储数据,就是TFRecord,它可以统一不同的原始数据格式,并且更加有效地管理不同的属性。

TFRecord格式

TFRecord文件中的数据都是用tf.train.Example Protocol Buffer的格式来存储的,tf.train.Example可以被定义为:

message Example{
  Features features = 1
}

message Features{
  map<string, Feature> feature = 1
}

message Feature{
  oneof kind{
    BytesList bytes_list = 1
    FloatList float_list = 1
    Int64List int64_list = 1
  }
}

可以看出Example是一个嵌套的数据结构,其中属性名称可以为一个字符串,其取值可以是字符串BytesList、实数列表FloatList或整数列表Int64List。

将数据转化为TFRecord格式

以下代码是将MNIST输入数据转化为TFRecord格式:

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

# 生成整数型的属性
def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 生成浮点型的属性
def _float_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
#若想保存为数组,则要改成value=value即可

# 生成字符串型的属性
def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

mnist = input_data.read_data_sets("/tensorflow_google", dtype=tf.uint8, one_hot=True)
images = mnist.train.images
# 训练数据所对应的正确答案,可以作为一个属性保存在TFRecord中
labels = mnist.train.labels
# 训练数据的图像分辨率,这可以作为Example中的一个属性
pixels = images.shape[1]
num_examples = mnist.train.num_examples

# 输出TFRecord文件的地址
filename = "/tensorflow_google/mnist_output.tfrecords"
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
  # 将图像矩阵转换成一个字符串
  image_raw = images[index].tostring()
  # 将一个样例转化为Example Protocol Buffer, 并将所有的信息写入这个数据结构
  example = tf.train.Example(features=tf.train.Features(feature={
    'pixels': _int64_feature(pixels),
    'label': _int64_feature(np.argmax(labels[index])),
    'image_raw': _bytes_feature(image_raw)}))

  # 将一个Example写入TFRecord文件
  writer.write(example.SerializeToString())
writer.close()

本程序将MNIST数据集中所有的训练数据存储到了一个TFRecord文件中,若数据量较大,也可以存入多个文件。

从TFRecord文件中读取数据

以下代码可以从上面代码中的TFRecord中读取单个或多个训练数据:

# -*- coding: utf-8 -*-
import tensorflow as tf

# 创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(["/Users/gaoyue/文档/Program/tensorflow_google/chapter7"
                         "/mnist_output.tfrecords"])

# 从文件中读出一个样例,也可以使用read_up_to函数一次性读取多个样例
# _, serialized_example = reader.read(filename_queue)
_, serialized_example = reader.read_up_to(filename_queue, 6) #读取6个样例
# 解析读入的一个样例,如果需要解析多个样例,可以用parse_example函数
# features = tf.parse_single_example(serialized_example, features={
# 解析多个样例
features = tf.parse_example(serialized_example, features={
  # TensorFlow提供两种不同的属性解析方法
  # 第一种是tf.FixedLenFeature,得到的解析结果为Tensor
  # 第二种是tf.VarLenFeature,得到的解析结果为SparseTensor,用于处理稀疏数据
  # 解析数据的格式需要与写入数据的格式一致
  'image_raw': tf.FixedLenFeature([], tf.string),
  'pixels': tf.FixedLenFeature([], tf.int64),
  'label': tf.FixedLenFeature([], tf.int64),
})

# tf.decode_raw可以将字符串解析成图像对应的像素数组
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()
# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# 每次运行可以读取TFRecord中的一个样例,当所有样例都读完之后,会重头读取
# for i in range(10):
#   image, label, pixel = sess.run([images, labels, pixels])
#   # print(image, label, pixel)
#   print(label, pixel)

# 读取TFRecord中的前6个样例,若加入循环,则会每次从上次输出的地方继续顺序读6个样例
image, label, pixel = sess.run([images, labels, pixels])
print(label, pixel)

sess.close()

>> [7 3 4 6 1 8] [784 784 784 784 784 784]

输出结果显示,从TFRecord文件中顺序读出前6个样例。

以上这篇使用TFRecord存取多个数据案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Tensorflow之构建自己的图片数据集TFrecords的方法

    学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程. tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集.并没有过多讲述怎么构建自己的图片数据集tfrecords. 流程是:制作数据集-读取数据集--加入队列 先贴完整的代码: #encoding=utf-8 import os import tensorflow as tf from PIL import Image cwd = os.getcwd() classes = {'test'

  • Tensorflow 实现将图像与标签数据转化为tfRecord文件

    tensorflow中如果要对神经网络模型进行训练,需要把训练数据转换为tfrecord格式才能被读取,tensorflow的model文件里直接提供了相应的脚本文件在下面的文件夹中: cd tensorflow/models/research/object_detection/dataset_tools 其中包括: 1.create_coco_tf_record.py:注意,这个代码需要解析json格式的标签文件 2.create_pascal_tf_record.py:注意,这个代码需要解析

  • TFRecord格式存储数据与队列读取实例

    Tensor Flow官方网站上提供三种读取数据的方法 1. 预加载数据:在Tensor Flow图中定义常量或变量来保存所有数据,将数据直接嵌到数据图中,当训练数据较大时,很消耗内存. 如 x1=tf.constant([0,1]) x2=tf.constant([1,0]) y=tf.add(x1,x2) 2.填充数据:使用sess.run()的feed_dict参数,将Python产生的数据填充到后端,之前的MNIST数据集就是通过这种方法.也有消耗内存,数据类型转换耗时的缺点. 3. 从

  • Tensorflow中使用tfrecord方式读取数据的方法

    前言 本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用多种方式来读取自己的数据.如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来.如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了.此篇博客就主要讲一下Tensorflow官方推荐的一种较为高效的数据读取方式--tfre

  • tensorflow入门:TFRecordDataset变长数据的batch读取详解

    在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行:但如果每条数据的长度不一样(常见于语音.视频.NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法: 1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord:这个方法的问题在于:若是有大量

  • Tensorflow使用tfrecord输入数据格式

    Tensorflow 提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式. 1. TFRecord格式介绍 TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Exampl

  • 将自己的数据集制作成TFRecord格式教程

    在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入 此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格 式的数据,此处就不再介绍了.建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列. 1. 原本的数据集 此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片. 2.制作成TFRecord格式 tfrecord会根据你选择输入文件的类,自

  • 使用TFRecord存取多个数据案例

    TensorFlow提供了一种统一的格式来存储数据,就是TFRecord,它可以统一不同的原始数据格式,并且更加有效地管理不同的属性. TFRecord格式 TFRecord文件中的数据都是用tf.train.Example Protocol Buffer的格式来存储的,tf.train.Example可以被定义为: message Example{ Features features = 1 } message Features{ map<string, Feature> feature =

  • js实现股票实时刷新数据案例

    近来学习炒股,免不了上班时间看盘,总不能光明正大的用电脑看行情,一直盯着手机影响也不好,容易引起"关注". 所以就想自己做一个网页来达到看盘的目的,一个只显示几个关键数字的网页肯定不会引起怀疑.有想法了,就开始实现吧. 准备工作: 1.数据来源 2.网页数据显示 先帖出来源码,后面讲解 <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/T

  • VBA处理数据与Python Pandas处理数据案例比较分析

    需求: 现有一个 csv文件,包含'CNUM'和'COMPANY'两列,数据里包含空行,且有内容重复的行数据. 要求: 1)去掉空行: 2)重复行数据只保留一行有效数据: 3)修改'COMPANY'列的名称为'Company_New': 4)并在其后增加六列,分别为'C_col','D_col','E_col','F_col','G_col','H_col'. 一,使用 Python Pandas来处理: import pandas as pd import numpy as np from p

  • Python 存取npy格式数据实例

    数据处理的时候主要通过两个函数 (1):np.save("test.npy",数据结构) ----存数据 (2):data =np.load('test.npy") ----取数据 给2个例子如下(存列表) 1. z = [[[1, 2, 3], ['w']], [[1, 2, 3], ['w']]] np.save('test.npy', z) x = np.load('test.npy') x: ->array([[list([1, 2, 3]), list(['w

  • Java使用easyExcel导出excel数据案例

    easyExcel简介: Java领域解析.生成Excel比较有名的框架有Apache poi.jxl等.但他们都存在一个严重的问题就是非常的耗内存.如果你的系统并发量不大的话可能还行,但是一旦并发上来后一定会OOM或者JVM频繁的full gc. easyExcel是阿里巴巴开源的一个excel处理框架,以使用简单.节省内存著称. easyExcel采用一行一行的解析模式,并将一行的解析结果以观察者的模式通知处理 easyExcel能大大减少占用内存的主要原因是在解析Excel时没有将文件数据

  • postgresql 删除重复数据案例详解

    1.建表 /* Navicat Premium Data Transfer Source Server : localhost Source Server Type : PostgreSQL Source Server Version : 110012 Source Host : localhost:5432 Source Catalog : postgres Source Schema : public Target Server Type : PostgreSQL Target Server

  • Ajax responseText解析json数据案例详解

    解决ajax处理服务器端返回结果responseText中是JSON的数据. 第一,json格式的文件内容如下: { "city":"ShangHai", "telephone":"123456789" } 第二,服务器端返回的json数据就是上述的内容在responseText中,现在要取出来,方法有两种: 方法1: var json=JSON.parse(request.responseText); alert(json.

  • SQL Server批量插入数据案例详解

    在SQL Server 中插入一条数据使用Insert语句,但是如果想要批量插入一堆数据的话,循环使用Insert不仅效率低,而且会导致SQL一系统性能问题.下面介绍SQL Server支持的两种批量数据插入方法:Bulk和表值参数(Table-Valued Parameters),高效插入数据. 新建数据库: --Create DataBase create database BulkTestDB; go use BulkTestDB; go --Create Table Create tab

  • Vue之使用mockjs生成模拟数据案例详解

    目录 在项目中安装mockjs 在Vue项目中使用mockjs的基本流程 Mock语法规范 数据模板定义规范(Data Template Definition,DTD) 数据占位符定义规范(Data Placeholder Definition,DPD) Mock.mock() Mock.Random() 在项目中安装mockjs 在项目目录下执行以下安装命令 npm install mockjs --save 在Vue项目中使用mockjs的基本流程 安装完成后,在项目src/utils目录下

  • Python爬虫采集Tripadvisor数据案例实现

    目录 前言 第三方库 开发环境 开始代码 请求数据 2. 获取数据(网页源代码) 3. 解析数据(提取我们想要的数据内容 详情页链接) 4. 发送请求(访问所有的详情页链接) 获取数据 5. 解析数据 6.保存数据 7.得到数据 前言 Tripadvisor是全球领先的旅游网站,主要提供来自全球旅行者的点评和建议,全面覆盖全球的酒店.景点.餐厅.航空公司 ,以及旅行规划和酒店.景点.餐厅预订功能.Tripadvisor及旗下网站在全球49个市场设有分站,月均独立访问量达4.15亿. 第三方库 r

随机推荐