tensorflow TFRecords文件的生成和读取的方法

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据。

TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的定义。

message Example {
  Features features = 1;
};
message Features {
  map<string, Feature> feature = 1;
};
message Feature {
  oneof kind {
  BytesList bytes_list = 1;
  FloatList float_list = 2;
  Int64List int64_list = 3;
}
}; 

下面将介绍如何生成和读取tfrecords文件:

首先介绍tfrecords文件的生成,直接上代码:

from random import shuffle
import numpy as np
import glob
import tensorflow as tf
import cv2
import sys
import os 

# 因为我装的是CPU版本的,运行起来会有'warning',解决方法入下,眼不见为净~
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

shuffle_data = True
image_path = '/path/to/image/*.jpg' 

# 取得该路径下所有图片的路径,type(addrs)= list
addrs = glob.glob(image_path)
# 标签数据的获得具体情况具体分析,type(labels)= list
labels = ... 

# 这里是打乱数据的顺序
if shuffle_data:
  c = list(zip(addrs, labels))
  shuffle(c)
  addrs, labels = zip(*c) 

# 按需分割数据集
train_addrs = addrs[0:int(0.7*len(addrs))]
train_labels = labels[0:int(0.7*len(labels))] 

val_addrs = addrs[int(0.7*len(addrs)):int(0.9*len(addrs))]
val_labels = labels[int(0.7*len(labels)):int(0.9*len(labels))] 

test_addrs = addrs[int(0.9*len(addrs)):]
test_labels = labels[int(0.9*len(labels)):] 

# 上面不是获得了image的地址么,下面这个函数就是根据地址获取图片
def load_image(addr): # A function to Load image
  img = cv2.imread(addr)
  img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  # 这里/255是为了将像素值归一化到[0,1]
  img = img / 255.
  img = img.astype(np.float32)
  return img 

# 将数据转化成对应的属性
def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 

def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 

def _float_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 

# 下面这段就开始把数据写入TFRecods文件 

train_filename = '/path/to/train.tfrecords' # 输出文件地址 

# 创建一个writer来写 TFRecords 文件
writer = tf.python_io.TFRecordWriter(train_filename) 

for i in range(len(train_addrs)):
  # 这是写入操作可视化处理
  if not i % 1000:
    print('Train data: {}/{}'.format(i, len(train_addrs)))
    sys.stdout.flush()
  # 加载图片
  img = load_image(train_addrs[i]) 

  label = train_labels[i] 

  # 创建一个属性(feature)
  feature = {'train/label': _int64_feature(label),
        'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))} 

  # 创建一个 example protocol buffer
  example = tf.train.Example(features=tf.train.Features(feature=feature)) 

  # 将上面的example protocol buffer写入文件
  writer.write(example.SerializeToString()) 

writer.close()
sys.stdout.flush()

上面只介绍了train.tfrecords文件的生成,其余的validation,test举一反三吧。。

接下来介绍tfrecords文件的读取:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
data_path = 'train.tfrecords' # tfrecords 文件的地址 

with tf.Session() as sess:
  # 先定义feature,这里要和之前创建的时候保持一致
  feature = {
    'train/image': tf.FixedLenFeature([], tf.string),
    'train/label': tf.FixedLenFeature([], tf.int64)
  }
  # 创建一个队列来维护输入文件列表
  filename_queue = tf.train.string_input_producer([data_path], num_epochs=1) 

  # 定义一个 reader ,读取下一个 record
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue) 

  # 解析读入的一个record
  features = tf.parse_single_example(serialized_example, features=feature) 

  # 将字符串解析成图像对应的像素组
  image = tf.decode_raw(features['train/image'], tf.float32) 

  # 将标签转化成int32
  label = tf.cast(features['train/label'], tf.int32) 

  # 这里将图片还原成原来的维度
  image = tf.reshape(image, [224, 224, 3]) 

  # 你还可以进行其他一些预处理.... 

  # 这里是创建顺序随机 batches(函数不懂的自行百度)
  images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, min_after_dequeue=10) 

  # 初始化
  init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  sess.run(init_op) 

  # 启动多线程处理输入数据
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord) 

  .... 

  #关闭线程
  coord.request_stop()
  coord.join(threads)
  sess.close() 

好了,就介绍到这里。。,有什么问题可以留言。。大家一起学习。。希望对大家的学习有所帮助,也希望大家多多支持我们。

您可能感兴趣的文章:

  • TensorFlow高效读取数据的方法示例
  • 详解Tensorflow数据读取有三种方式(next_batch)
(0)

相关推荐

  • 详解Tensorflow数据读取有三种方式(next_batch)

    Tensorflow数据读取有三种方式: Preloaded data: 预加载数据 Feeding: Python产生数据,再把数据喂给后端. Reading from file: 从文件中直接读取 这三种有读取方式有什么区别呢? 我们首先要知道TensorFlow(TF)是怎么样工作的. TF的核心是用C++写的,这样的好处是运行快,缺点是调用不灵活.而Python恰好相反,所以结合两种语言的优势.涉及计算的核心算子和运行框架是用C++写的,并提供API给Python.Python调用这些A

  • TensorFlow高效读取数据的方法示例

    概述 最新上传的mcnn中有完整的数据读写示例,可以参考. 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况). 对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(t

  • tensorflow TFRecords文件的生成和读取的方法

    TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据. TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的.以下的代码给出了tf.train.Example的定义. message Example { Features features = 1; }; message Features { map<string, Feature> feature = 1; }; mes

  • tensorflow学习笔记之tfrecord文件的生成与读取

    训练模型时,我们并不是直接将图像送入模型,而是先将图像转换为tfrecord文件,再将tfrecord文件送入模型.为进一步理解tfrecord文件,本例先将6幅图像及其标签转换为tfrecord文件,然后读取tfrecord文件,重现6幅图像及其标签. 1.生成tfrecord文件 import os import numpy as np import tensorflow as tf from PIL import Image filenames = [ 'images/cat/1.jpg'

  • Tensorflow中TFRecord生成与读取的实现

    目录 一.为什么使用TFRecord? 二. 生成TFRecord简单实现方式 三. 生成TFRecord文件完整代码实例 TFRecord读取 四. 读取TFRecord的简单实现方式 五.tf.contrib.slim模块读取TFrecord文件完整代码实例 参考: 一.为什么使用TFRecord? 正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时

  • Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取

    单一数据读取方式: 第一种:slice_input_producer() # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...] [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True) 第二种:string_input_producer() # 需要定义文件读取器,然后通过读取器中的

  • python神经网络tfrecords文件的写入读取及内容解析

    目录 学习前言 tfrecords格式是什么 tfrecords的写入 tfrecords的读取 测试代码 1.tfrecords文件的写入 2.tfrecords文件的读取 学习前言 前一段时间对SSD预测与训练的整体框架有了一定的了解,但是对其中很多细节还是把握的不清楚.今天我决定好好了解以下tfrecords文件的构造. tfrecords格式是什么 tfrecords是一种二进制编码的文件格式,tensorflow专用.能将任意数据转换为tfrecords.更好的利用内存,更方便复制和移

  • PHP读取、解析eml文件及生成网页的方法示例

    本文实例讲述了PHP读取.解析eml文件及生成网页的方法.分享给大家供大家参考,具体如下: php读取eml实例,本实例可以将导出eml文件解析成正文,并且可以将附件保存到服务器.不多说直接贴代码了. <?php // Author: richard e42083458@163.com // gets parameters error_reporting(E_ALL ^ (E_WARNING|E_NOTICE)); header("Content-type: text/html; char

  • c#使用Dataset读取XML文件动态生成菜单的方法

    本文实例讲述了c#使用Dataset读取XML文件动态生成菜单的方法.分享给大家供大家参考.具体实现方法如下: Step 1:Form1 上添加一个ToolStripContainer控件 Step2:实现代码 private void Form2_Load(object sender, EventArgs e) { CMenuEx menu = new CMenuEx(); string sPath = "D://Menu.xml";//xml的内容 if (menu.FileExi

  • php生成与读取excel文件

    在网站中经常会生成表格,CSV和Excel都是常用的报表格式,CSV相对来说比较简单,如果大家有疑问我会相继发布一些CSV的实例,这里主要介绍用PHP来生成和读取Excel文件. 要执行下面的函数,首先要引入一个类库:PHPExcel,PHPExcel是一个强大的PHP类库,用来读写不同的文件格式,比如说Excel 2007,PDF格式,HTML格式等等,这个类库是建立在Microsoft's OpenXML和PHP 的基础上的,对Excel提供的强大的支持,比如设置工作薄,字体样式,图片以及边

  • python读取文件名称生成list的方法

    经常需要读取某个文件夹下所有的图像文件. 我使用python写了个简单的代码,读取某个文件夹下某个后缀的文件,将文件名生成为文本(csv格式) import fnmatch import os import pandas as pd import numpy as np import sys InputStra = sys.argv[1] InputStrb = sys.argv[2] def ReadSaveAddr(Stra,Strb): #print(Stra) #print(Strb)

  • python深度学习TensorFlow神经网络模型的保存和读取

    目录 之前的笔记里实现了softmax回归分类.简单的含有一个隐层的神经网络.卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用.为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西. TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型. 下面代码给出了保存TensorFlow模型的方法: import tensorflow as tf # 声明两个变量

随机推荐