将自己的数据集制作成TFRecord格式教程

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

#生成字符串类型的属性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #输出TFRecord文件的地址

 writer = tf.python_io.TFRecordWriter(filename)

 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每个图片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #将图片转化成二进制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()

 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代码,运行完后会产生生成的.tfrecord文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename):
 #创建一个reader来读取TFRecord文件中的样例
 reader = tf.TFRecordReader()
 #创建一个队列来维护输入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example

 #解析读入的一个样例,如果需要解析多个,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #将字符串解析成图像对应的像素数组
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)

 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize

 image_batch, label_batch = tf.train.shuffle_batch([images, labels],
              batch_size=batchsize,
              capacity=capacity,
              min_after_dequeue=min_after_dequeue
              )

 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函数

if __name__ =="__main__":
 #训练图片两张为一个batch,进行训练,测试图片一起进行测试
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)

  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")

  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此时,生成的batch,就可以feed进网络了。

以上这篇将自己的数据集制作成TFRecord格式教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Tensorflow中使用tfrecord方式读取数据的方法

    前言 本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用多种方式来读取自己的数据.如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来.如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了.此篇博客就主要讲一下Tensorflow官方推荐的一种较为高效的数据读取方式--tfre

  • 使用TFRecord存取多个数据案例

    TensorFlow提供了一种统一的格式来存储数据,就是TFRecord,它可以统一不同的原始数据格式,并且更加有效地管理不同的属性. TFRecord格式 TFRecord文件中的数据都是用tf.train.Example Protocol Buffer的格式来存储的,tf.train.Example可以被定义为: message Example{ Features features = 1 } message Features{ map<string, Feature> feature =

  • Tensorflow使用tfrecord输入数据格式

    Tensorflow 提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式. 1. TFRecord格式介绍 TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Exampl

  • Tensorflow 实现将图像与标签数据转化为tfRecord文件

    tensorflow中如果要对神经网络模型进行训练,需要把训练数据转换为tfrecord格式才能被读取,tensorflow的model文件里直接提供了相应的脚本文件在下面的文件夹中: cd tensorflow/models/research/object_detection/dataset_tools 其中包括: 1.create_coco_tf_record.py:注意,这个代码需要解析json格式的标签文件 2.create_pascal_tf_record.py:注意,这个代码需要解析

  • tensorflow入门:TFRecordDataset变长数据的batch读取详解

    在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行:但如果每条数据的长度不一样(常见于语音.视频.NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法: 1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord:这个方法的问题在于:若是有大量

  • Tensorflow之构建自己的图片数据集TFrecords的方法

    学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程. tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集.并没有过多讲述怎么构建自己的图片数据集tfrecords. 流程是:制作数据集-读取数据集--加入队列 先贴完整的代码: #encoding=utf-8 import os import tensorflow as tf from PIL import Image cwd = os.getcwd() classes = {'test'

  • TFRecord格式存储数据与队列读取实例

    Tensor Flow官方网站上提供三种读取数据的方法 1. 预加载数据:在Tensor Flow图中定义常量或变量来保存所有数据,将数据直接嵌到数据图中,当训练数据较大时,很消耗内存. 如 x1=tf.constant([0,1]) x2=tf.constant([1,0]) y=tf.add(x1,x2) 2.填充数据:使用sess.run()的feed_dict参数,将Python产生的数据填充到后端,之前的MNIST数据集就是通过这种方法.也有消耗内存,数据类型转换耗时的缺点. 3. 从

  • 将自己的数据集制作成TFRecord格式教程

    在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入 此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格 式的数据,此处就不再介绍了.建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列. 1. 原本的数据集 此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片. 2.制作成TFRecord格式 tfrecord会根据你选择输入文件的类,自

  • 将数据集制作成VOC数据集格式的实例

    在做目标检测任务时,若使用Github已复现的论文时,需首先将自己的数据集转化为VOC数据集的格式,因为论文作者使用的是公开数据集VOC 2007.VOC2012.COCO等类型数据集做方法验证与比对. 一.VOC数据集格式 --VOCdevkit2007 --VOC2007 --Annotations (xml格式的文件) --000001.xml --ImageSets --Layout --Main --train.txt --test.txt --val.txt --trainval.t

  • Python爬取读者并制作成PDF

    学了下beautifulsoup后,做个个网络爬虫,爬取读者杂志并用reportlab制作成pdf.. crawler.py 复制代码 代码如下: #!/usr/bin/env python #coding=utf-8 """     Author:         Anemone     Filename:       getmain.py     Last modified:  2015-02-19 16:47     E-mail:         anemone@82

  • 使用PlatformView将 Android 控件view制作成Flutter插件

    目录 引言 1. FlutterPlugin 创建 2. 创建 Android 控件 3. 注册 Android 控件 4. 封装 Android 层通信交互 ‘CustomViewController’ 代码说明 5. 在 flutter 中如何使用已注册的 Android 控件(view) 代码说明 如何使用这个View 6. 附上 example 完整代码 引言 小编最近在项目中实现相机识别人脸的功能,将 Android 封装的控件 view 进行中转,制作成 FlutterPlugin

  • activex 控件制作成cab包的问题

    我的控件是自己写的,现在需要把它做成.cab包,以使其可以在客户端自动下载注册.我只知道使用cabarc.exe这个工具.但是不知道怎么写inf文件,怎么加入证书.ActiveX发布步骤:      创建PVK文件[私人密匙文件]      makecert   -sk   DigitalTitan   DigitalTitan.pvk      makecert   -n   CN=TelStar   TelStar      创建CER文件[公司证书]      makecert   -sk

  • java实现可安装的exe程序实例详解

    java实现可安装的exe程序实例详解 通过编写Java代码,实现可安装的exe文件的一般思路: 1.在eclipse中创建java项目,然后编写Java代码,将编写好的Java项目导出一个.jar格式的jar包: 2.通过安装exe4j软件,将导出的.jar格式的文件制作成.exe格式的可执行的文件,(注意:此时的.exe文件只是可以执行,还不能够安装): 3.通过安装Inno setup软件,将可执行的.exe格式的文件..jar格式的文件以及其它需要的文件制作成一个可安装的.exe格式的文

  • Maven在Java8下如何忽略Javadoc的编译错误详解

    JavaDoc简介And基础知识 (一) Java注释类型 //用于单行注释. /*...*/用于多行注释,从/*开始,到*/结束,不能嵌套. /**...*/则是为支持jdk工具javadoc.exe而特有的注释语句. 说明:javadoc 工具能从java源文件中读取第三种注释,并能识别注释中用@标识的一些特殊变量(见表),制作成Html格式的类说明文档.javadoc不但能对一个 java源文件生成注释文档,而且能对目录和包生成交叉链接的html格式的类说明文档,十分方便. (二)Java

  • TensorFlow数据输入的方法示例

    读取数据(Reading data) TensorFlow输入数据的方式有四种: tf.data API:可以很容易的构建一个复杂的输入通道(pipeline)(首选数据输入方式)(Eager模式必须使用该API来构建输入通道) Feeding:使用Python代码提供数据,然后将数据feeding到计算图中. QueueRunner:基于队列的输入通道(在计算图计算前从队列中读取数据) Preloaded data:用一个constant常量将数据集加载到计算图中(主要用于小数据集) 1. t

  • Python实现将DNA序列存储为tfr文件并读取流程介绍

    最近导师让我跑模型,生物信息方向的,我一个学计算机的,好多东西都看不明白.现在的方向大致是,用深度学习的模型预测病毒感染人类的风险. 既然是病毒,就需要拿到它的DNA,也就是碱基序列,然后把这些ACGT序列丢进模型里面,然后就是预测能不能感染人类,说实话,估计结果不会好,现在啥都是transformer,而且我看的这篇论文,我认为仅仅从DNA序列大概预测不出什么东西. 但是就那样吧,现在数据去哪里下载,需要下载什么样的数据,下载完成后怎么处理我还是一脸懵逼,但是假设上面都处理好了,然后即使把数据

随机推荐