pytorch加载自己的图片数据集的2种方法详解

目录
  • ImageFolder 加载数据集
  • 使用pytorch提供的Dataset类创建自己的数据集。
  • Dataset加载数据集
  • 总结

pytorch加载图片数据集有两种方法。

1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别

导入ImageFolder()包
from torchvision.datasets import ImageFolder

在Flower_Orig_dataset文件夹下有flower_orig 和 sunflower这两个文件夹, 这两个文件夹下放着同一个类别的图片。 使用 ImageFolder 加载的图片, 就会返回图片信息和对应的label信息, 但是label信息是根据文件夹给出的, 如flower_orig就是标签0, sunflower就是标签1。

ImageFolder 加载数据集

1. 导入包和设置transform

import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader

transforms = transforms.Compose([
    transforms.Resize(256),    # 将图片短边缩放至256,长宽比保持不变:
    transforms.CenterCrop(224),   #将图片从中心切剪成3*224*224大小的图片
    transforms.ToTensor()          #把图片进行归一化,并把数据转换成Tensor类型
]) 

2.加载数据集: 将分类图片的父目录作为路径传递给ImageFolder(), 并传入transform。这样就有了要加载的数据集, 之后就可以使用DataLoader加载数据, 并构建网络训练。

path = r'D:\数据集\Flower_Orig_dataset'
data_train = datasets.ImageFolder(path, transform=transforms)
data_loader = DataLoader(data_train, batch_size=64, shuffle=True)
for i, data in enumerate(data_loader):
    images, labels = data
    print(images.shape)
    print(labels.shape)
    break

使用pytorch提供的Dataset类创建自己的数据集。

具体步骤:

1.  首先要有一个txt文件, 这个文件格式是: 图片路径  标签.  这样的格式, 所以使用os库, 遍历自己的图片名, 并把标签和图片路径写入txt文件。

2. 有了这个txt文件, 我们就可以在类里面构造我们的数据集.

2.1    把图片路径和图片标签分割开, 有两个列表, 一个列表是图片路径名, 一个列表是标签号, 有一点就是第 i 个图片列表和 第 i 个标签是对应的

3. 重写__len__方法  和  __getitem__方法

3.1 getitem方法中, 获得对应的图片路径,并用PIL库读取文件把图片transfrom后, 在getitem函数中返回读取的图片和标签即可

4.就可以构建数据集实例和加载数据集.

定义一个用来生成[ 图片路径 标签] 这样的txt文件函数

def make_txt(root, file_name, label):
    path = os.path.join(root, file_name)
    data = os.listdir(path)
    f = open(path+'\\'+'f.txt', 'w')
    for line in data:
        f.write(line+' '+str(label)+'\n')
    f.close()
#调用函数生成两个文件夹下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)

将连个txt文件合并成一个txt文件,表示数据集所有的图片和标签

def link_txt(file1, file2):
    txt_list = []
    path = r'D:\数据集\Flower_Orig_dataset\data.txt'

    f = open(path, 'a')

    f1 = open(file1, 'r')
    data1 = f1.readlines()
    for line in data1:
        txt_list.append(line)

    f2 = open(file2, 'r')
    data2 = f2.readlines()
    for line in data2:
        txt_list.append(line)

    for line in txt_list:
        f.write(line)

    f.close()
    f1.close()
    f2.close()

#调用函数, 将两个文件夹下的txt文件合并
file1 = r'D:\数据集\Flower_Orig_dataset\flower_orig\f.txt'
file2 = r'D:\数据集\Flower_Orig_dataset\sunflower\f.txt'
link_txt(file1=file1, file2=file2)

现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息

Dataset加载数据集

我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径
class MyDataset(Dataset):
    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = img_path

        self.txt_root = self.root + 'data.txt'
        f = open(self.txt_root, 'r')
        data = f.readlines()

        imgs = []
        labels = []
        for line in data:
            line = line.rstrip()
            word = line.split()
            imgs.append(os.path.join(self.root, word[1], word[0]))

            labels.append(word[1])
        self.img = imgs
        self.label = labels
        self.transform = transform

    def __len__(self):
        return len(self.label)

    def __getitem__(self, item):
        img = self.img[item]
        label = self.label[item]

        img = Image.open(img).convert('RGB')

        #此时img是PIL.Image类型   label是str类型

        if transforms is not None:
            img = self.transform(img)

        label = np.array(label).astype(np.int64)
        label = torch.from_numpy(label)

        return img, label

加载我们的数据集:

path = r'D:\数据集\Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)

data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

接下来我们就可以构建我们的网络架构:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.maxpool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16,5,3)

        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(55*55*5, 1200)
        self.fc2 = nn.Linear(1200,64)
        self.fc3 = nn.Linear(64,2)

    def forward(self,x):
        x = self.maxpool(self.relu(self.conv1(x)))    #113
        x = self.maxpool(self.relu(self.conv2(x)))    #55
        x = x.view(-1, self.num_flat_features(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s

        return num_features
 

训练我们的网络:

model = Net()

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(data_loader):
        images, label = data

        out = model(images)

        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if(i+1)%10 == 0:
            print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))
            running_loss = 0.0

print('finished train')

保存网络模型(这里不止是保存参数,还保存了网络结构)

#保存模型
torch.save(net, 'model_name.pth')   #保存的是模型, 不止是w和b权重值

# 读取模型
model = torch.load('model_name.pth')

总结

到此这篇关于pytorch加载自己的图片数据集的2种方法的文章就介绍到这了,更多相关pytorch加载图片数据集内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch 数据集图片显示方法

    图片显示 pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用. 同样给一些刚入门的同学在使用载入的数据显示图片的时候带来一些难以理解的地方,这里主要是将Tensor与numpy转换的过程,理解了这些就可以就行转换了 CIAFA10数据集 首先载入数据集,这里做了一些数据处理,包括图片尺寸.数据归一化等 import torch from torch.a

  • PyTorch读取Cifar数据集并显示图片的实例讲解

    首先了解一下需要的几个类所在的package from torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等. #DataLoader读入的数据类型是PIL.Image

  • Pytorch自己加载单通道图片用作数据集训练的实例

    pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100.在torchvision的dataset包里面,用的时候直接调用就行了.具体的调用格式可以去看文档(目前好像只有英文的).网上也有很多源代码. 不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了.pytorch在torchvision.dataset包里面封装过一个函数ImageFolder().这个函数功能很强大,只要你直接将

  • pytorch加载自己的图片数据集的2种方法详解

    目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 总结 pytorch加载图片数据集有两种方法. 1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别 导入ImageFolder()包 from torchvision.datasets import ImageFolder 在

  • pytorch加载自己的图像数据集实例

    之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码.到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题. 参考文章https://www.jb51.net/article/177613.htm 下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor.最后读取第一张图片并显示. # 数据处理 import os import torch from torch.utils import data fr

  • 使用pytorch加载并读取COCO数据集的详细操作

    目录 环境配置 基础知识:元祖.字典.数组 利用PyTorch读取COCO数据集 利用PyTorch读取自己制作的数据集 如何使用pytorch加载并读取COCO数据集 环境配置基础知识:元祖.字典.数组利用PyTorch读取COCO数据集利用PyTorch读取自己制作的数据集 环境配置 看pytorch入门教程 基础知识:元祖.字典.数组 # 元祖 a = (1, 2) # 字典 b = {'username': 'peipeiwang', 'code': '111'} # 数组 c = [1

  • pytorch加载语音类自定义数据集的方法教程

    前言 pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法 __len()__ :返回数据集中数据的数量 __getitem()__ :返回支持下标索引方式获取的一个数据 torch.utils.data.DataLoad

  • Python计算图片数据集的均值方差示例详解

    目录 前言 Python批量reshape图片 参考 计算数据集均值和方差 前言 在做图像处理的时候,有时候需要得到整个数据集的均值方差数值,以下代码可以解决你的烦恼: (做这个之前一定保证所有的图片都是统一尺寸,不然算出来不对,我的代码里设计的是512*512,可以自己调整,同一尺寸的代码我也有: Python批量reshape图片 # -*- coding: utf-8 -*- """ Created on Thu Aug 23 16:06:35 2018 @author

  • Python实现普通图片转ico图标的方法详解

    目录 简介 历史攻略 下载安装包 下载地址 安装后缀pythonmagick - whl文件 案例源码 效果图 简介 ICO是一种图标文件格式,图标文件可以存储单个图案.多尺寸.多色板的图标文件.一个图标实际上是多张不同格式的图片的集合体,并且还包含了一定的透明区域.它是图标文件格式的一种,可以存储单个图案.多尺寸.多色板的图标文件.图标是具有明确指代含义的计算机图形.其中桌面图标是软件标识,界面中的图标是功能标识. 历史攻略 pip安装第三方库全攻略:普通安装.安装whl后缀文件.使用国内镜像

  • C#动态加载组件后如何在开发环境中调试详解

    动态加载组件 那就是简单的Assembly.Load动态加载dll而以.这网上资料也有不少.基本的思路基本上就是在本地上一个指定目录如[plugs]存在着一堆dll文件.主程序在初始运行时一般会把指定目录下的dll一次性用Assembly.Load加载进来.只要把指定目录变成从网络加载,或者加载指定目录前先检查网络上的是否有新版本.这就简单做成个最简单版本的热更新. 多数网上的资料就是然后就没有然后了.很多人就发现产品是通过动态加载组件了.但开发人员根本无法调试啊.不能调试就意味着开发难度大啊.

  • Java Swing实现窗体添加背景图片的2种方法详解

    本文实例讲述了Java Swing实现窗体添加背景图片的2种方法.分享给大家供大家参考,具体如下: 在美化程序时,常常需要在窗体上添加背景图片.通过搜索和测试,发现了2种有效方式.下面分别介绍. 1. 利用JLabel加载图片 利用JLabel自带的setIcon(Icon icon)加载icon,并设置JLabel对象的位置和大小使其完全覆盖窗体.这是一个很取巧的办法,代码非常简单,如下所示. JLabel lbBg = new JLabel(imageIcon); lbBg.setBound

  • 利用canvas中toDataURL()将图片转为dataURL(base64)的方法详解

    将图片转为base64的好处 将图片转换为Base64编码,可以让你很方便地在没有上传文件的条件下将图片插入其它的网页.编辑器中. 这对于一些小的图片是极为方便的,因为你不需要再去寻找一个保存图片的地方. 将图片转换成base64编码的,在web网上一般用于小图片上,不仅可以减少图片的请求数量(集合到js.css代码中),还可以防止因为一些相对路径等问题导致图片404错误. 引言 假设一个应用场景:由于某些特殊原因从服务端请求到图片路径,要求通过该路径获取对应图片的 base64 dataURL

  • JS 图片压缩原理与实现方法详解

    本文实例讲述了JS 图片压缩原理与实现方法.分享给大家供大家参考,具体如下: 前言 说起图片压缩,大家想到的或者平时用到的很多工具都可以实现,例如,客户端类的有图片压缩工具 PPDuck3, JS 实现类的有插件 compression.js ,亦或是在线处理类的 OSS 上传,文件上传后,在访问文件时中也有图片的压缩配置选项,不过,能不能自己撸一套 JS 实现的图片压缩代码呢?当然可以,那我们先来理一下思路. 压缩思路 涉及到 JS 的图片压缩,我的想法是需要用到 Canvas 的绘图能力,通

随机推荐