pytorch加载自己的数据集源码分享

目录
  • 一、标准的数据集流程梳理
    • 数据来源
  • 二、实现加载自己的数据集
    • 1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的)
    • 2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能
  • 三、源码

一、标准的数据集流程梳理

分为几个步骤
数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用

数据来源

直接读取了x和y的数据变量,对比后面的就从把对应的路径写进了文本文件中,通过加载器进行读取

x = torch.linspace(1, 10, 10)   # 训练数据 linspace返回一个一维的张量,(最小值,最大值,多少个数)
print(x)
y = torch.linspace(10, 1, 10)   # 标签
print(y)

将数据加载进数据库

输出的结果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加载器进行加载,才能迭代遍历

import torch.utils.data as Data
torch_dataset = Data.TensorDataset(x, y)  # 对给定的 tensor 数据,将他们包装成 dataset
#输出的结果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加载器进行加载,才能迭代遍历
print(torch_dataset)

所以要想看里面的内容,就需要用迭代进行操作或者查看。

BATCH_SIZE=5
loader = Data.DataLoader(#使用支持的默认的数据集加载的方式
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,       # torch TensorDataset format   加载数据集
    batch_size=BATCH_SIZE,       # mini batch size 5
    shuffle=False,                # 要不要打乱数据 (打乱比较好)
    num_workers=2,               # 多线程来读数据
)

def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader): #加载数据集的时候起的作用很奇怪
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
            print("*"*100)
if __name__ == '__main__':
    show_batch()

二、实现加载自己的数据集

实现自己的数据集就需要完成对dataset类的重载。这个类的重载完成几个函数的作用

  • 初始化数据集中的数据以及标签__init__()
  • 返回数据和对应标签__getitem__
  • 返回数据集的大小__len__

基本的数据集的方法就是完成以上步骤,但是可以想想数据集通常是一些图片和标签组成,而这些数据集以及标签是保存在计算机上,具有相对应的位置,那么直接访问对应的位置因为是在文件夹下需要进行遍历等一系列操作,而且这就显得和dataset类没有解耦,因为有时候在这些位置的操作可能会有一些特殊操作,所以如果能够将其位置保存在文本文件中可能就会方便很多,所以就采取保存文本文件的方式。

# 自定义数据集类
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, *args):
        super().__init__()
        # 初始化数据集包含的数据和标签
        pass

    def __getitem__(self, index):
        # 根据索引index从文件中读取一个数据
        # 对数据预处理
        # 返回数据和对应标签
        pass

    def __len__(self):
        # 返回数据集的大小
        return len()

1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的)

所以这里新建一个数据库就是新建了两个文本文件,然后加载器通过文本文件就将图片以及label加载进去了。而标准的数据集操作是使用了自带的数据集接口,在加载的时候也不用再去实现相关的__getitem__方法

  • 数组定义
  • 将绝对路径加载进数组中
  • 数组定义
  • 将绝对路径加载进数组中
  • 通过os.walk操作
  • os.walk可以获得根路径、文件夹以及文件,并会一直进行迭代遍历下去,直至只有文件才会结束
  • 将数组的内容打乱顺序
  • 分别将绝对路径对应的数组内容写进文本文件里,那么这里的文本文件就是保存的数据库,其实数据就是一个保存相关信息或者其内容的文件,而标准也是将将其数据保存在了一个地方,然后对应到标准接口就可以加载了(Data.TensorDataset以及Data.DataLoader)

以下代码用于生成对应的train.txt val.txt

'''
生成训练集和测试集,保存在txt文件中
'''
import os
import random

train_ratio = 0.6

test_ratio = 1-train_ratio

rootdata = r"dataset"

#数组定义
train_list, test_list = [],[]
data_list = []

class_flag = -1
# 将绝对路径加载进数组中
for a,b,c in os.walk(rootdata):#os.walk可以获得根路径、文件夹以及文件,并会一直进行迭代遍历下去,直至只有文件才会结束
    print(a)
    for i in range(len(c)):
        data_list.append(os.path.join(a,c[i]))

    for i in range(0,int(len(c)*train_ratio)):
        train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n' #class_flag表示分类的类别
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio),len(c)):
        test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n'
        test_list.append(test_data)

    class_flag += 1 

print(train_list)
# 将数组的内容打乱顺序
random.shuffle(train_list)
random.shuffle(test_list)

#分别将绝对路径对应的数组内容写进文本文件里
with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)

2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能

初始化数据集中的数据以及标签、相关变量__init__()

def __init__(self, txt_path, train_flag=True):
     #初始化图片对应的变量imgs_info以及一些相关变量
     self.imgs_info = self.get_images(txt_path) #imgs_info保存了图片以及标签
     self.train_flag = train_flag

     self.train_tf = transforms.Compose([#对训练集的图片进行预处理
             transforms.Resize(224),
             transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
             transforms.ToTensor(),
             transform_BZ
         ])
     self.val_tf = transforms.Compose([#对测试集的图片进行预处理
             transforms.Resize(224),
             transforms.ToTensor(),
             transform_BZ
         ])

返回数据对应标签__getitem__

def __getitem__(self, index):
     img_path, label = self.imgs_info[index]
     #打开图片,并将RGBA转换为RGB,这里是通过PIL库打开图片的
     img = Image.open(img_path)
     img = img.convert('RGB')
     img = self.padding_black(img) #将图片添加上黑边的
     if self.train_flag: #选择是训练集还是测试集
         img = self.train_tf(img)
     else:
         img = self.val_tf(img)
     label = int(label)

     return img, label

返回数据集的大小__len__

def __len__(self):
     return len(self.imgs_info)

由于前面已经对集成dataset的类进行了实现三种方法,那么就可以在加载器中进行加载,将加载后的数据传入到train函数或者test函数都可以

  • train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True):使用加载器加载数据
  • train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model):将数据传入train或者test中进行训练或者测试
  • 注意:LoadData是继承了dataset的类
if __name__=='__main__':
    batch_size = 16

    # # 给训练集和测试集分别创建一个数据集加载器
    train_data = LoadData("train.txt", True)
    valid_data = LoadData("test.txt", False)

    train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)

    for X, y in test_dataloader:
        print("Shape of X [N, C, H, W]: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break

三、源码

链接: https://pan.baidu.com/s/19Oo87gbcm9e8zvYGkBi95A 提取码: 2tss

到此这篇关于pytorch加载自己的数据集源码分享的文章就介绍到这了,更多相关pytorch加载自己的数据集内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch加载语音类自定义数据集的方法教程

    前言 pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法 __len()__ :返回数据集中数据的数量 __getitem()__ :返回支持下标索引方式获取的一个数据 torch.utils.data.DataLoad

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch加载自己的图片数据集的2种方法详解

    目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 总结 pytorch加载图片数据集有两种方法. 1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别 导入ImageFolder()包 from torchvision.datasets import ImageFolder 在

  • Pytorch自己加载单通道图片用作数据集训练的实例

    pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100.在torchvision的dataset包里面,用的时候直接调用就行了.具体的调用格式可以去看文档(目前好像只有英文的).网上也有很多源代码. 不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了.pytorch在torchvision.dataset包里面封装过一个函数ImageFolder().这个函数功能很强大,只要你直接将

  • PyTorch加载数据集梯度下降优化

    目录 一.实现过程 1.准备数据 2.设计模型 3.构造损失函数和优化器 4.训练过程 5.结果展示 二.参考文献 一.实现过程 1.准备数据 与PyTorch实现多维度特征输入的逻辑回归的方法不同的是:本文使用DataLoader方法,并继承DataSet抽象类,可实现对数据集进行mini_batch梯度下降优化. 代码如下: import torch import numpy as np from torch.utils.data import Dataset,DataLoader clas

  • pytorch加载自己的图像数据集实例

    之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码.到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题. 参考文章https://www.jb51.net/article/177613.htm 下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor.最后读取第一张图片并显示. # 数据处理 import os import torch from torch.utils import data fr

  • PyTorch加载自己的数据集实例详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力. 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能.为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载. 数据集存放大致有以下两种方式: (1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg root/cat_dog/cat.02.jpg ...

  • pytorch加载自己的数据集源码分享

    目录 一.标准的数据集流程梳理 数据来源 二.实现加载自己的数据集 1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的) 2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能 三.源码 一.标准的数据集流程梳理 分为几个步骤数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用 数据来源 直接读取了x和y的数据变量,对比后面的就从把对应的路径写进了文本文件

  • Android用于加载xml的LayoutInflater源码超详细分析

    1.在view的加载和绘制流程中:文章链接 我们知道,定义在layout.xml布局中的view是通过LayoutInflate加载并解析成Java中对应的View对象的.那么具体的解析过程是哪样的. 先看onCreate方法,如果我们的Activity是继承自AppCompactActivity.android是通过getDelegate返回的对象setContentView,这个mDelegate 是AppCompatDelegateImpl的实例. @Override protected

  • Android图片加载利器之Picasso源码解析

    看到了这里,相信大家对Picasso的使用已经比较熟悉了,本篇博客中将从基本的用法着手,逐步的深入了解其设计原理. Picasso的代码量在众多的开源框架中算得上非常少的一个了,一共只有35个class文件,但是麻雀虽小,五脏俱全.好了下面跟随我的脚步,出发了. 基本用法 Picasso.with(this).load(imageUrl).into(imageView); with(this)方法 public static Picasso with(Context context) { if

  • Android实现基于滑动的SQLite数据分页加载技术(附demo源码下载)

    本文实例讲述了Android实现基于滑动的SQLite数据分页加载技术.分享给大家供大家参考,具体如下: main.xml如下: <menu xmlns:android="http://schemas.android.com/apk/res/android" > <item android:id="@+id/action_settings" android:orderInCategory="100" android:showAs

  • Pytorch如何加载自己的数据集(使用DataLoader读取Dataset)

    目录 1.Pytorch加载数据集会用到官方整理好的数据集 2.Dataset 3.DataLoader 4.查看数据 5.总结 1.Pytorch加载数据集会用到官方整理好的数据集 很多时候我们需要加载自己的数据集,这时候我们需要使用Dataset和DataLoader Dataset:是被封装进DataLoader里,实现该方法封装自己的数据和标签. DataLoader:被封装入DataLoaderIter里,实现该方法达到数据的划分. 2.Dataset 阅读源码后,我们可以指导,继承该

  • 使用pytorch加载并读取COCO数据集的详细操作

    目录 环境配置 基础知识:元祖.字典.数组 利用PyTorch读取COCO数据集 利用PyTorch读取自己制作的数据集 如何使用pytorch加载并读取COCO数据集 环境配置基础知识:元祖.字典.数组利用PyTorch读取COCO数据集利用PyTorch读取自己制作的数据集 环境配置 看pytorch入门教程 基础知识:元祖.字典.数组 # 元祖 a = (1, 2) # 字典 b = {'username': 'peipeiwang', 'code': '111'} # 数组 c = [1

随机推荐