Tensorflow分类器项目自定义数据读入的实现

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")

目录下的结构如下:

相应的label.txt如下:

动漫
风景
美女
物语
樱花

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
  labels = []
  images = []
  print('Handle images')
  # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
  for index, name in enumerate(class_names):
    # 这里是拼接后的子目录path
    classpath = os.path.join(image_path, name)
    # 先判断一下是否是目录
    if not os.path.isdir(classpath):
      continue
    # limit是测试时候用的这里可以去除
    limit = 0
    for image_name in os.listdir(classpath):
      if limit >= max_size:
        break
      # 这里是拼接后的待处理的图片path
      imagepath = os.path.join(classpath, image_name)
      count = count + 1
      limit = limit + 1
      # 利用Image打开图片
      img = Image.open(imagepath)
      # 缩放到你最初确定要处理的图片分辨率大小
      img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
      # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
      img = img.convert("L")
      # 转为numpy数组
      img = np.array(img)
      # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
      img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
      # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
      labels.append([index])
      # 添加到images中,最后统一处理
      images.append(img)
      # 循环中一些状态的输出,可以去除
      print("{} class: {} {} limit: {} {}"
         .format(count, index + 1, class_names[index], limit, imagepath))
  # 最后一次性将images和labels都转换成numpy数组
  npy_data = np.array(images)
  npy_labels = np.array(labels)
  # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
  np.savez(data_path, x=npy_data, y=npy_labels)
  print("Save images by npz")
else:
  # 如果存在序列化号的数据,便直接读取,提高速度
  npy_data = np.load(data_path)["x"]
  npy_labels = np.load(data_path)["y"]
  print("Load images by npz")
image_data = npy_data
labels_data = npy_labels

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
  train_test_split(image_data, labels_data, test_size=0.2, random_state=6)

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFlow实现创建分类器

    本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下 创建一个iris数据集的分类器. 加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾.iris数据集有三类花,但这里仅预测是否是山鸢尾.导入iris数据集和工具库,相应地对原数据集进行转换. # Combining Everything Together #---------------------------------- # This file will perform binary c

  • Tensorflow分类器项目自定义数据读入的实现

    在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错.但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里. 首先提一下需要用到的模块: import os import keras import matplotlib.pyplot as plt from PIL import Image from keras

  • 完美解决TensorFlow和Keras大数据量内存溢出的问题

    内存溢出问题是参加kaggle比赛或者做大数据量实验的第一个拦路虎. 以前做的练手小项目导致新手产生一个惯性思维--读取训练集图片的时候把所有图读到内存中,然后分批训练. 其实这是有问题的,很容易导致OOM.现在内存一般16G,而训练集图片通常是上万张,而且RGB图,还很大,VGG16的图片一般是224x224x3,上万张图片,16G内存根本不够用.这时候又会想起--设置batch,但是那个batch的输入参数却又是图片,它只是把传进去的图片分批送到显卡,而我OOM的地方恰是那个"传进去&quo

  • Tensorflow 多线程与多进程数据加载实例

    在项目中遇到需要处理超级大量的数据集,无法载入内存的问题就不用说了,单线程分批读取和处理(虽然这个处理也只是特别简单的首尾相连的操作)也会使瓶颈出现在CPU性能上,所以研究了一下多线程和多进程的数据读取和预处理,都是通过调用dataset api实现 1. 多线程数据读取 第一种方法是可以直接从csv里读取数据,但返回值是tensor,需要在sess里run一下才能返回真实值,无法实现真正的并行处理,但如果直接用csv文件或其他什么文件存了特征值,可以直接读取后进行训练,可使用这种方法. imp

  • TensorFlow人工智能学习创建数据实现示例详解

    目录 一.数据创建 1.tf.constant() 2.tf.convert_to_tensor() 3.tf.zeros() 4.tf.fill() 二.数据随机初始化 ①tf.random.normal() ②tf.random.truncated_normal() ③tf.random.uniform() ④tf.random.shuffle() 一.数据创建 1.tf.constant() 创建自定义类型,自定义形状的数据,但不能创建类似于下面In [59]这样的,无法解释的数据. 2.

  • Spring中自定义数据类型转换的方法详解

    目录 类型转换服务 实现Converter接口 实现ConverterFactory接口 实现GenericConverter接口 环境:Spring5.3.12.RELEASE. Spring 3引入了一个core.onvert包,提供一个通用类型转换系统.系统定义了一个SPI来实现类型转换逻辑,以及一个API来在运行时执行类型转换.在Spring容器中,可以使用这个系统作为PropertyEditor实现的替代,将外部化的bean属性值字符串转换为所需的属性类型.还可以在应用程序中需要类型转

  • SpringBoot+Hibernate实现自定义数据验证及异常处理

    目录 前言 Hibernate实现字段校验 自定义校验注解 使用AOP处理校验异常 全局异常类处理异常 前言 在进行 SpringBoot 项目开发中,经常会碰到属性合法性问题,而面对这个问题通常的解决办法就是通过大量的 if 和 else 判断来解决的,例如: @PostMapping("/verify") @ResponseBody public Object verify(@Valid User user){ if (StringUtils.isEmpty(user.getNam

  • C#开发教程之利用特性自定义数据导出到Excel

    网上C#导出Excel的方法有很多.但用来用去感觉不够自动化.于是花了点时间,利用特性做了个比较通用的导出方法.只需要根据实体类,自动导出想要的数据 1.在NuGet上安装Aspose.Cells或者用微软自带类库也可以 2.需要导出的数据的实例类: using System.ComponentModel; using System.Reflection; using System.Runtime.Serialization; public class OrderReport { [Displa

  • C语言实现txt数据读入内存/CPU缓存实例详解

    摘要 C实现将txt数据读入内存/CPU缓存的函数,不多说,实现如下. 1. 实现代码 #include "stdafx.h" #include <stdio.h> #include <stdlib.h> int filelength(FILE *fp); char *readfile(char *path); int main(void){ char *string; string=readfile("C:/Users/Joe WANG/Deskto

  • php通过header发送自定义数据方法

    本文将介绍如何通过header发送自定义数据.发送请求时,除了可以使用$_GET/$_POST发送数据,也可以把数据放在header中传输过去. 发送header: 我们定义了三个参数,token.language.region,放入header发送过去 <?php $url = 'http://www.example.com'; $header = array('token:JxRaZezavm3HXM3d9pWnYiqqQC1SJbsU','language:zh','region:GZ')

  • python3+PyQt5重新实现自定义数据拖放处理

    本文分成两部分,第一部分通过python3+PyQt5实现自定义数据的拖放操作.第二部分则对第一部分的程序进行修改,增加拖放操作时,菜单提示是否移动或拷贝,还有可以通过ctrl键盘来设置移动过程中拷贝源而非会将源删除. 自定义数据MIME数据类型QMimeData,MIME是一种用于处理具有多个组成部分的自定义数据的标准化格式.MIME数据由一个数据类型和一个子类型构成–例如,text/plain,text/html,image/png,要处理自定义MIME数据,就必须要选用一种自定义数据类型和

随机推荐