pytorch实现建立自己的数据集(以mnist为例)

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)

  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()

  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()

  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的

  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

  LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 详解PyTorch手写数字识别(MNIST数据集)

    MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程.虽然网上的案例比较多,但还是要自己实现一遍.代码采用 PyTorch 1.0 编写并运行. 导入相关库 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, t

  • pytorch中的自定义数据处理详解

    pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch下大型数据集(大型图片)的导入方式

    使用torch.utils.data.Dataset类 处理图片数据时, 1. 我们需要定义三个基本的函数,以下是基本流程 class our_datasets(Data.Dataset): def __init__(self,root,is_resize=False,is_transfrom=False): #这里只是个参考.按自己需求写. self.root=root self.is_resize=is_resize self.is_transfrom=is_transfrom self.i

  • Pytorch 神经网络—自定义数据集上实现教程

    第一步.导入需要的包 import os import scipy.io as sio import numpy as np import torch import torch.nn as nn import torch.backends.cudnn as cudnn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from torchvision import transforms, ut

  • Pytorch 实现数据集自定义读取

    以读取VOC2012语义分割数据集为例,具体见代码注释: VocDataset.py from PIL import Image import torch import torch.utils.data as data import numpy as np import os import torchvision import torchvision.transforms as transforms import time #VOC数据集分类对应颜色标签 VOC_COLORMAP = [[0,

  • pytorch实现建立自己的数据集(以mnist为例)

    本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据. 加载并保存图像信息 首先导入需要的库,定义各种路径. import os import matplotlib from keras.datasets import mnist import numpy as np from torch.utils.data.dataset import Dataset from PIL import Image import scipy.misc

  • 计算pytorch标准化(Normalize)所需要数据集的均值和方差实例

    pytorch做标准化利用transforms.Normalize(mean_vals, std_vals),其中常用数据集的均值方差有: if 'coco' in args.dataset: mean_vals = [0.471, 0.448, 0.408] std_vals = [0.234, 0.239, 0.242] elif 'imagenet' in args.dataset: mean_vals = [0.485, 0.456, 0.406] std_vals = [0.229,

  • PyTorch加载自己的数据集实例详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力. 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能.为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载. 数据集存放大致有以下两种方式: (1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg root/cat_dog/cat.02.jpg ...

  • pytorch学习教程之自定义数据集

    自定义数据集 在训练深度学习模型之前,样本集的制作非常重要.在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,下面完整的试验自定义样本集的整个流程. 开发环境 Ubuntu 18.04 pytorch 1.0 pycharm 实验目的 掌握pytorch中数据集相关的API接口和类 熟悉数据集制作的整个流程 实验过程 1.收集图像样本 以简单的猫狗二分类为例,可以在网上下载一些猫狗图片.创建以下目录: data-------------根目录 data/test-------测

  • Pytorch中使用ImageFolder读取数据集时忽略特定文件

    目录 一.使用ImageFolder读取数据集时忽略特定文件 二.ImageFolder只读取部分类别文件夹 一.使用ImageFolder读取数据集时忽略特定文件 如果事先知道需要忽略哪些文件,当然直接从数据集里删除就行了.但如果需要在程序运行时动态确认,或者筛选规则比较复杂,人工不好做,就需要让ImageFolder在读取时使用自定义的筛选规则. ImageFolder有一个可选参数为is_valid_file,参数类型为可调用的函数,该函数传入一个str参数,返回一个bool值.当返回值为

  • 聊聊基于pytorch实现Resnet对本地数据集的训练问题

    目录 1.dataset.py(先看代码的总体流程再看介绍) 2.network.py 3.train.py 4.结果与总结 本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py.network.py.dataset.py以及train.py文件,功能是对本地的数据集进行分类.本文介绍逻辑是总分形式,即首先对总流程进行一个概括,然后分别介绍每个流程中的实现过程(代码+流程图+文字的介绍). 对于整个项目的流程首

  • pytorch加载自己的数据集源码分享

    目录 一.标准的数据集流程梳理 数据来源 二.实现加载自己的数据集 1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的) 2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能 三.源码 一.标准的数据集流程梳理 分为几个步骤数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用 数据来源 直接读取了x和y的数据变量,对比后面的就从把对应的路径写进了文本文件

  • Pytorch转keras的有效方法,以FlowNet为例讲解

    Pytorch凭借动态图机制,获得了广泛的使用,大有超越tensorflow的趋势,不过在工程应用上,TF仍然占据优势.有的时候我们会遇到这种情况,需要把模型应用到工业中,运用到实际项目上,TF支持的PB文件和TF的C++接口就成为了有效的工具.今天就给大家讲解一下Pytorch转成Keras的方法,进而我们也可以获得Pb文件,因为Keras是支持tensorflow的,我将会在下一篇博客讲解获得Pb文件,并使用Pb文件的方法. Pytorch To Keras 首先,我们必须有清楚的认识,网上

  • PyTorch如何创建自己的数据集

    目录 PyTorch创建自己的数据集 pytorch常用数据集的使用 PyTorch创建自己的数据集 图片文件在同一的文件夹下 思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__方法,示例代码如下: class ImageFolder(Dataset):     def __init__(self, folder_path):         self.files = sorted(glob.glob('%s/*.*' % folder_path)

  • 详解如何从TensorFlow的mnist数据集导出手写体数字图片

    在TensorFlow的官方入门课程中,多次用到mnist数据集. mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二进制文件. 如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片.了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解. 下面先给出通过TensorFlow api接口导出mnist手写体数字图片的python代码,再对代

随机推荐