python神经网络tfrecords文件的写入读取及内容解析

目录
  • 学习前言
  • tfrecords格式是什么
  • tfrecords的写入
  • tfrecords的读取
  • 测试代码
    • 1、tfrecords文件的写入
    • 2、tfrecords文件的读取

学习前言

前一段时间对SSD预测与训练的整体框架有了一定的了解,但是对其中很多细节还是把握的不清楚。今天我决定好好了解以下tfrecords文件的构造。

tfrecords格式是什么

tfrecords是一种二进制编码的文件格式,tensorflow专用。能将任意数据转换为tfrecords。更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。

之所以使用到tfrecords格式是因为当今数据爆炸的情况下,使用普通的数据格式不仅麻烦,而且速度慢,这种专门为tensorflow定制的数据格式可以大大增快数据的读取,而且将所有内容规整,在保证速度的情况下,使得数据更加简单明晰。

tfrecords的写入

这个例子将会讲述如何将MNIST数据集写入到tfrecords,本次用到的MNIST数据集会利用tensorflow原有的库进行导入。

from tensorflow.examples.tutorials.mnist import input_data
# 读取MNIST数据集
mnist = input_data.read_data_sets('./MNIST_data', dtype=tf.float32, one_hot=True)

对于MNIST数据集而言,其中的训练集是mnist.train,而它的数据可以分为images和labels,可通过如下方式获得。

# 获得image,shape为(55000,784)
images = mnist.train.images
# 获得label,shape为(55000,10)
labels = mnist.train.labels
# 获得一共具有多少张图片
num_examples = mnist.train.num_examples

接下来定义存储TFRecord文件的地址,同时创建一个writer来写TFRecord文件。

# 存储TFRecord文件的地址
filename = 'record/output.tfrecords'
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)

此时便可以按照一定的格式写入了,此时需要对每一张图片进行循环并写入,在tf.train.Features中利用features字典定义了数据保存的方式。以image_raw为例,其经过函数_float_feature处理后,存储到tfrecords文件的’image/encoded’位置上。

# 将每张图片都转为一个Example,并写入
for i in range(num_examples):
    image_raw = images[i]  # 读取每一幅图像
    image_string = images[i].tostring()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/class/label': _int64_feature(np.argmax(labels[i])),
                'image/encoded': _float_feature(image_raw),
                'image/encoded_tostring': _bytes_feature(image_string)
            }
        )
    )
    print(i,"/",num_examples)
    writer.write(example.SerializeToString())  # 将Example写入TFRecord文件

在最终存入前,数据还需要经过处理,处理方式如下:

# 生成整数的属性
def _int64_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
# 生成浮点数的属性
def _float_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
# 生成字符串型的属性
def _bytes_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

tfrecords的读取

tfrecords的读取首先要创建一个reader来读取TFRecord文件中的Example。

# 创建一个reader来读取TFRecord文件中的Example
reader = tf.TFRecordReader()

再创建一个队列来维护输入文件列表。

# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['record/output.tfrecords'])

利用reader读取输入文件列表队列,并用parse_single_example将读入的Example解析成tensor

# 从文件中读出一个Example
_, serialized_example = reader.read(filename_queue)
# 用parse_single_example将读入的Example解析成tensor
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([784], tf.float32, default_value=tf.zeros([784], dtype=tf.float32)),
        'image/encoded_tostring': tf.FixedLenFeature([], tf.string)
    }
)

此时我们得到了一个features,实际上它是一个类似于字典的东西,我们额可以通过字典的方式读取它内部的内容,而字典的索引就是我们再写入tfrecord文件时所用的feature。

# 将字符串解析成图像对应的像素数组
labels = tf.cast(features['image/class/label'], tf.int32)
images = tf.cast(features['image/encoded'], tf.float32)
images_tostrings = tf.decode_raw(features['image/encoded_tostring'], tf.float32)

最后利用一个循环输出:

# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取
for i in range(5):
    label, image = sess.run([labels, images])
    images_tostring = sess.run(images_tostrings)
    print(np.shape(image))
    print(np.shape(images_tostring))
    print(label)
    print("#########################")

测试代码

1、tfrecords文件的写入

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 生成整数的属性
def _int64_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
# 生成浮点数的属性
def _float_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
# 生成字符串型的属性
def _bytes_feature(value):
    if not isinstance(value,list) and not isinstance(value,np.ndarray):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
# 读取MNIST数据集
mnist = input_data.read_data_sets('./MNIST_data', dtype=tf.float32, one_hot=True)
# 获得image,shape为(55000,784)
images = mnist.train.images
# 获得label,shape为(55000,10)
labels = mnist.train.labels
# 获得一共具有多少张图片
num_examples = mnist.train.num_examples
# 存储TFRecord文件的地址
filename = 'record/Mnist_Out.tfrecords'
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
# 将每张图片都转为一个Example,并写入
for i in range(num_examples):
    image_raw = images[i]  # 读取每一幅图像
    image_string = images[i].tostring()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/class/label': _int64_feature(np.argmax(labels[i])),
                'image/encoded': _float_feature(image_raw),
                'image/encoded_tostring': _bytes_feature(image_string)
            }
        )
    )
    print(i,"/",num_examples)
    writer.write(example.SerializeToString())  # 将Example写入TFRecord文件
print('data processing success')
writer.close()

运行结果为:

……
54993 / 55000
54994 / 55000
54995 / 55000
54996 / 55000
54997 / 55000
54998 / 55000
54999 / 55000
data processing success

2、tfrecords文件的读取

import tensorflow as tf
import numpy as np
# 创建一个reader来读取TFRecord文件中的Example
reader = tf.TFRecordReader()
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer(['record/Mnist_Out.tfrecords'])
# 从文件中读出一个Example
_, serialized_example = reader.read(filename_queue)
# 用parse_single_example将读入的Example解析成tensor
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([784], tf.float32, default_value=tf.zeros([784], dtype=tf.float32)),
        'image/encoded_tostring': tf.FixedLenFeature([], tf.string)
    }
)
# 将字符串解析成图像对应的像素数组
labels = tf.cast(features['image/class/label'], tf.int32)
images = tf.cast(features['image/encoded'], tf.float32)
images_tostrings = tf.decode_raw(features['image/encoded_tostring'], tf.float32)
sess = tf.Session()
# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 每次运行读取一个Example。当所有样例读取完之后,在此样例中程序会重头读取
for i in range(5):
    label, image = sess.run([labels, images])
    images_tostring = sess.run(images_tostrings)
    print(np.shape(image))
    print(np.shape(images_tostring))
    print(label)
    print("#########################")

运行结果为:

#########################
(784,)
(784,)
7
#########################
#########################
(784,)
(784,)
4
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
1
#########################
#########################
(784,)
(784,)
9
#########################

以上就是python神经网络tfrecords文件的写入读取及内容解析的详细内容,更多关于python神经网络tfrecords写入读取的资料请关注我们其它相关文章!

(0)

相关推荐

  • python进阶TensorFlow神经网络拟合线性及非线性函数

    目录 一.拟合线性函数 生成随机坐标 神经网络拟合 代码 二.拟合非线性函数 生成二次随机点 神经网络拟合 代码 一.拟合线性函数 学习率0.03,训练1000次: 学习率0.05,训练1000次: 学习率0.1,训练1000次: 可以发现,学习率为0.05时的训练效果是最好的. 生成随机坐标 1.生成x坐标 2.生成随机干扰 3.计算得到y坐标 4.画点 # 生成随机点 def Produce_Random_Data(): global x_data, y_data # 生成x坐标 x_dat

  • Python深度学习pytorch神经网络多输入多输出通道

    目录 多输入通道 多输出通道 1 × 1 1\times1 1×1卷积层 虽然每个图像具有多个通道和多层卷积层.例如彩色图像具有标准的RGB通道来指示红.绿和蓝.但是到目前为止,我们仅展示了单个输入和单个输出通道的简化例子.这使得我们可以将输入.卷积核和输出看作二维张量. 当我们添加通道时,我们的输入和隐藏的表示都变成了三维张量.例如,每个RGB输入图像具有 3 × h × w 3\times{h}\times{w} 3×h×w的形状.我们将这个大小为3的轴称为通道(channel)维度.在本节

  • Python深度学习pytorch神经网络图像卷积运算详解

    目录 互相关运算 卷积层 特征映射 由于卷积神经网络的设计是用于探索图像数据,本节我们将以图像为例. 互相关运算 严格来说,卷积层是个错误的叫法,因为它所表达的运算其实是互相关运算(cross-correlation),而不是卷积运算.在卷积层中,输入张量和核张量通过互相关运算产生输出张量. 首先,我们暂时忽略通道(第三维)这一情况,看看如何处理二维图像数据和隐藏表示.下图中,输入是高度为3.宽度为3的二维张量(即形状为 3 × 3 3\times3 3×3).卷积核的高度和宽度都是2. 注意,

  • TFRecord格式存储数据与队列读取实例

    Tensor Flow官方网站上提供三种读取数据的方法 1. 预加载数据:在Tensor Flow图中定义常量或变量来保存所有数据,将数据直接嵌到数据图中,当训练数据较大时,很消耗内存. 如 x1=tf.constant([0,1]) x2=tf.constant([1,0]) y=tf.add(x1,x2) 2.填充数据:使用sess.run()的feed_dict参数,将Python产生的数据填充到后端,之前的MNIST数据集就是通过这种方法.也有消耗内存,数据类型转换耗时的缺点. 3. 从

  • Tensorflow中使用tfrecord方式读取数据的方法

    前言 本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用多种方式来读取自己的数据.如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来.如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了.此篇博客就主要讲一下Tensorflow官方推荐的一种较为高效的数据读取方式--tfre

  • tensorflow TFRecords文件的生成和读取的方法

    TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据. TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的.以下的代码给出了tf.train.Example的定义. message Example { Features features = 1; }; message Features { map<string, Feature> feature = 1; }; mes

  • python神经网络tfrecords文件的写入读取及内容解析

    目录 学习前言 tfrecords格式是什么 tfrecords的写入 tfrecords的读取 测试代码 1.tfrecords文件的写入 2.tfrecords文件的读取 学习前言 前一段时间对SSD预测与训练的整体框架有了一定的了解,但是对其中很多细节还是把握的不清楚.今天我决定好好了解以下tfrecords文件的构造. tfrecords格式是什么 tfrecords是一种二进制编码的文件格式,tensorflow专用.能将任意数据转换为tfrecords.更好的利用内存,更方便复制和移

  • python对csv文件追加写入列的方法

    python对csv文件追加写入列,具体内容如下所示: 原始数据 [外链图片转存失败(img-zQSQWAyQ-1563597916666)(C:\Users\innduce\AppData\Roaming\Typora\typora-user-images\1557663419920.png)] import pandas as pd import numpy as np data = pd.read_csv(r'平均值.csv') print(data.columns)#获取列索引值 dat

  • Python中文件的写入读取以及附加文字方法

    今天学习到python的读取文件部分. 还是以一段代码为例: filename='programming.txt' with open(filename,'w') as file_object: file_object.write("I love programming.\n") file_object.write("I love travelling.\n") 在这里调用open打开文件,两个实参,一个是要打开的文件名称,第二个实参('w')是告诉Python我们

  • 基于Python实现大文件分割和命名脚本过程解析

    日志文件分割.命名 工作中经常会收到测试同学.客户同学提供的日志文件,其中不乏几百M一G的也都有,毕竟压测一晚上产生的日志量还是很可观的,xDxD,因此不可避免的需要对日志进行分割,通常定位问题需要针对时间点,因此最好对分割后的日志文件使用文件中日志的开始.结束时间点来命名,这样使用起来最为直观,下面给大家分享两个脚本,分别作分割.命名,希望能够给大家提供一点点帮助: 大文件分割 用法: python split_big_file.py 输入文件全路径名 输入期望的分割后每个小文件的行数 Jus

  • python神经网络slim常用函数训练保存模型

    目录 学习前言 slim是什么 slim常用函数 1.slim = tf.contrib.slim 2.slim.create_global_step 3.slim.dataset.Dataset 4.slim.dataset_data_provider.DatasetDataProvider 5.slim.conv2d 6.slim.max_pool2d 7.slim.fully_connected 8.slim.learning.train 本次博文实现的目标 整体框架构建思路 1.整体框架

  • Python读写压缩文件的方法

    问题 你想读写一个gzip或bz2格式的压缩文件. 解决方案 gzip 和 bz2 模块可以很容易的处理这些文件. 两个模块都为 open() 函数提供了另外的实现来解决这个问题. 比如,为了以文本形式读取压缩文件,可以这样做: # gzip compression import gzip with gzip.open('somefile.gz', 'rt') as f: text = f.read() # bz2 compression import bz2 with bz2.open('so

  • Python如何获取文件路径/目录

    一.获取文件路径实现 1.1 获取当前文件路径 import os current_file_path = __file__ print(f"current_file_path: {current_file_path}") __file__变量其实有个问题,当文件被是被调用文件时__file__总是文件的绝对路径:但当文件是直接被执行的文件时,__file__并不总是文件的绝对路径,而是你执行该文件时给python传的路径.比如你是python xxx/yyy.py形式执行的,那么此时

  • python删除服务器文件代码示例

    本文主要研究的是Python编程删除服务器文件,具体实现 代码如下. 实例1 #coding:utf-8 import paramiko """ 创建文件 删除文件 root权限 """ ssh=paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.connect(hostname="192.168.1.37",po

  • Python 如何测试文件是否存在

    问题 你想测试一个文件或目录是否存在. 解决方案 使用 os.path 模块来测试一个文件或目录是否存在.比如: >>> import os >>> os.path.exists('/etc/passwd') True >>> os.path.exists('/tmp/spam') False >>> 你还能进一步测试这个文件时什么类型的. 在下面这些测试中,如果测试的文件不存在的时候,结果都会返回False: >>>

  • Python修改DBF文件指定列

    一.需求: 某公司每日收到一批DBF文件,A系统实时处理后将其中dealstat字段置为1(已处理).现在每日晚间B系统也需要处理该文件,因此需将文件中dealstat字段修改为空(未处理). 二.分析: 1.应创建副本进行修改 解答:使用shutil.copy 2.修改DBF 解答:使用dbf模块.此模块能找到的文档比较旧,需要结合代码进行理解. 三.代码实现: #!/usr/bin/env python # _*_ coding:utf-8 _*_ """ @Time :

随机推荐