使用PyTorch将文件夹下的图片分为训练集和验证集实例

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms

  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

(0)

相关推荐

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • PyTorch读取Cifar数据集并显示图片的实例讲解

    首先了解一下需要的几个类所在的package from torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等. #DataLoader读入的数据类型是PIL.Image

  • pytorch中的自定义数据处理详解

    pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数

  • pytorch 数据处理:定义自己的数据集合实例

    数据处理 版本1 #数据处理 import os import torch from torch.utils import data from PIL import Image import numpy as np #定义自己的数据集合 class DogCat(data.Dataset): def __init__(self,root): #所有图片的绝对路径 imgs=os.listdir(root) self.imgs=[os.path.join(root,k) for k in imgs

  • pytorch 实现将自己的图片数据处理成可以训练的图片类型

    为了使用自己的图像数据,需要仿照pytorch数据输入创建新的类,其中数据格式为numpy.ndarray. 将自己的图片保存到numpy.ndarray中,然后创建类 from torch.utils.data import Dataset import numpy as np class Dataset(Dataset): def __init__(self, path_img, path_target, transforms=None): self.train = path_img sel

  • 使用PyTorch将文件夹下的图片分为训练集和验证集实例

    PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png 使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开.我写了两个类来完成这个工作. import os import torch fro

  • python或C++读取指定文件夹下的所有图片

    本文实例为大家分享了python或C++读取指定文件夹下的所有图片,供大家参考,具体内容如下 1.python读取指定文件夹下的所有图片路径和图片文件名 import cv2 from os import walk,path def get_fileNames(rootdir): data=[] prefix = [] for root, dirs, files in walk(rootdir, topdown=True): for name in files: pre, ending = pa

  • 微信小程序云开发获取文件夹下所有文件(推荐)

    上周一个高中同学让我帮他做个图片展示的公众号,因为一直在加班的原因,所以一时忘了,昨晚想起来就赶紧加班加点的帮他弄了下,遇到了个问题,记录一下. 他的需求是要有个后台给他上传图片并且将图片归类,前端公众号根据每次不同的主题显示不同的封面和照片,但是服务器又不想买贵的,思来想去,最后决定用小程序云开发,连服务器和公众号认证费都免了,哈哈哈哈 因为改成小程序云开发,所以需求也有了些变动,最后改成不需要后台,图片直接在云开发控制台中上传,然后在小程序中添加一个专门用来对封面和名称做修改的管理员页面,其

  • 使用python如何删除同一文件夹下相似的图片

    前言 最近整理图片发现,好多图片都非常相似,于是写如下代码去删除,有两种方法: 注:第一种方法只对于连续图片(例一个视频里截下的图片)准确率也较高,其效率高:第二种方法准确率高,但效率低 方法一:相邻两个文件比较相似度,相似就把第二个加到新列表里,然后进行新列表去重,统一删除. 例如:有文件1-10,首先1和2相比较,若相似,则把2加入到新列表里,再接着2和3相比较,若不相似,则继续进行3和4比较-一直比到最后,然后删除新列表里的图片 代码如下: #!/usr/bin/env python #

  • linux下采用shell脚本实现批量为指定文件夹下图片添加水印的方法

    要实现linux下采用shell脚本批量为指定文件夹下图片添加水印,首先需要安装imagemagick: CentOS上安装: yum install ImageMagick -y Debian上安装: apt-get install ImageMagick -y 脚本: #!/bin/bash for each in /要处理的图片目录/*{.jpg,.gif} s=`du -k $each | awk '{print $1}'` if [ $s -gt 10 ]; then #convert

  • php随机显示指定文件夹下图片的方法

    本文实例讲述了php随机显示指定文件夹下图片的方法.分享给大家供大家参考.具体如下: 此代码会从指定的服务器文件夹随机选择一个图片进行显示,非常有用,图片格式为.gif,.jpg,.png <?php //This will get an array of all the gif, jpg and png images in a folder $img_array = glob("/path/to/images/*.{gif,jpg,png}",GLOB_BRACE); //Pi

  • 利用Python对文件夹下图片数据进行批量改名的代码实例

    1. 前言 我们最近在做一个使用flask 模拟 instagram 的图片分享网站, 需要一些基本的图片数据, 我们这里采用的是本地提供, 但是,使用爬虫从网上爬下来的图片,名字都是乱七八糟的,不利于编程,这里就需要对他们进行批量改名操作. 2. 基本思路 使用python 的os 模块,对文件夹进行遍历(listdir), 同时使用rename 进行改名操作 3. 实现效果 4. 实现代码 代码非常简单 # -*- coding:utf8 -*- import os class BatchR

  • asp.net 获取文件夹中的图片的代码

    前台: 复制代码 代码如下: <asp:DataList ID="DataList1" runat="server" RepeatDirection="Horizontal" RepeatColumns="5" CellSpacing="25"> <ItemTemplate> <img src="<%# Eval("FullName") %&

  • Java Servlet上传图片到指定文件夹并显示图片

    在学习Servlet过程中,针对图片上传做了一个Demo,实现的功能是:在a页面上传图片,点击提交后,将图片保存到服务器指定路径(D:/image):跳转到b页面,b页面读取展示绝对路径(D:/image)的图片.主要步骤如下: 步骤一:上传页面uploadphoto.jsp 需要注意两个问题: 1.form 的method必须是post的,get不能上传文件, 还需要加上enctype="multipart/form-data" 表示提交的数据是二进制文件. 2.需要提供type=&

  • Python批量重命名同一文件夹下文件的方法

    本文实例讲述了Python批量重命名同一文件夹下文件的方法.分享给大家供大家参考.具体分析如下: 朋友发了一个文件夹过来,里面的图片都以 .tmp 为后缀. 手工修改的话工作量太大.故写了一个 Python 脚本进行批量重命名. 对 Python 的标准库不熟,只能边查资料,或者 help() 边写代码. 三行代码就可以解决这一问题. 不过没有捕获异常.不能迭代同一目录下的所有文件. 代码如下: import os for file in os.listdir("."): if os.

随机推荐