pytorch从csv加载自定义数据模板的操作

整理了一套模板,全注释了,这个难点终于克服了

from PIL import Image
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
#放文件的路径
dir_path= './97/train/'
csv_path='./97/train.csv'
class Mydataset(Dataset):
 #传递数据路径,csv路径 ,数据增强方法
 def __init__(self, dir_path,csv, transform=None, target_transform=None):
  super(Mydataset, self).__init__()
  #一个个往列表里面加绝对路径
  self.path = []
  #读取csv
  self.data = pd.read_csv(csv)
  #对标签进行硬编码,例如0 1 2 3 4,把字母变成这个
  colorMap = {elem: index + 1 for index, elem in enumerate(set(self.data["label"]))}
  self.data['label'] = self.data['label'].map(colorMap)
  #创造空的label准备存放标签
  self.num = int(self.data.shape[0]) # 一共多少照片
  self.label = np.zeros(self.num, dtype=np.int32)
  #迭代得到数据路径和标签一一对应
  for index, row in self.data.iterrows():
   self.path.append(os.path.join(dir_path,row['filename']))
   self.label[index] = row['label'] # 将数据全部读取出来
  #训练数据增强
  self.transform = transform
  #验证数据增强在这里没用
  self.target_transform = target_transform
 #最关键的部分,在这里使用前面的方法
 def __getitem__(self, index):
  img =Image.open(self.path[index]).convert('RGB')
  labels = self.label[index]
  #在这里做数据增强
  if self.transform is not None:
   img = self.transform(img) # 转化tensor类型
  return img, labels
 def __len__(self):
  return len(self.data)
#数据增强的具体内容
transform = transforms.Compose(
 [transforms.ToTensor(),
  transforms.Resize(150),
  transforms.CenterCrop(150),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
#加载数据
train_data = Mydataset(dir_path=dir_path,csv=csv_path, transform=transform)
trainloader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=0)
#迭代训练
for i_batch,batch_data in enumerate(trainloader):
 image,label=batch_data

补充:pytorch—定义自己的数据集及加载训练

笔记:pytorch Conv2d 的宽高公式理解,pytorch 使用自己的数据集并且加载训练

一、pypi 镜像使用帮助

pypi 镜像每 5 分钟同步一次。

临时使用

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

注意,simple 不能少, 是 https 而不是 http

设为默认

修改 ~/.config/pip/pip.conf (Linux), %APPDATA%\pip\pip.ini (Windows 10)$HOME/Library/Application Support/pip/pip.conf (macOS) (没有就创建一个), 修改 index-urltuna,例如

[global]
index-url = https://pypi.tuna.tsinghua.edu.cn/simple

pip 和 pip3 并存时,只需修改 ~/.pip/pip.conf。

二、pytorch Conv2d 的宽高公式理解

三、pytorch 使用自己的数据集并且加载训练

import os
import sys
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import time
import random
import csv
from PIL import Image
def createImgIndex(dataPath, ratio):
 '''
 读取目录下面的图片制作包含图片信息、图片label的train.txt和val.txt
 dataPath: 图片目录路径
 ratio: val占比
 return:label列表
 '''
 fileList = os.listdir(dataPath)
 random.shuffle(fileList)
 classList = [] # label列表
 # val 数据集制作
 with open('data/val_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList)*ratio)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectionName = fileInfo[0] + '_' + fileInfo[1] # 切面名+标准与否
    row.append(os.path.join(dataPath, fileList[i])) # 图片路径
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 # train 数据集制作
 with open('data/train_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList) * ratio)+1, len(fileList)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectionName = fileInfo[0] + '_' + fileInfo[1] # 切面名+标准与否
    row.append(os.path.join(dataPath, fileList[i])) # 图片路径
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 print(classList, len(classList))
 return classList
def default_loader(path):
 '''定义读取文件的格式'''
 return Image.open(path).resize((128, 128),Image.ANTIALIAS).convert('RGB')
class MyDataset(Dataset):
 '''Dataset类是读入数据集数据并且对读入的数据进行索引'''
 def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
  super(MyDataset, self).__init__() #对继承自父类的属性进行初始化
  fh = open(txt, 'r') #按照传入的路径和txt文本参数,以只读的方式打开这个文本
  reader = csv.reader(fh)
  imgs = []
  for row in reader:
   imgs.append((row[0], int(row[1]))) # (图片信息,lable)
  self.imgs = imgs
  self.transform = transform
  self.target_transform = target_transform
  self.loader = loader

 def __getitem__(self, index):
  '''用于按照索引读取每个元素的具体内容'''
  # fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中row[0]和row[1]的信息
  fn, label = self.imgs[index]
  img = self.loader(fn)
  if self.transform is not None:
   img = self.transform(img) #数据标签转换为Tensor
  return img, label

 def __len__(self):
  '''返回数据集的长度'''
  return len(self.imgs)
class Model(nn.Module):
 def __init__(self, classNum=31):
  super(Model, self).__init__()
  # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
  # torch.nn.MaxPool2d(kernel_size, stride, padding)
  # input 维度 [3, 128, 128]
  self.cnn = nn.Sequential(
   nn.Conv2d(3, 64, 3, 1, 1), # [64, 128, 128]
   nn.BatchNorm2d(64),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [64, 64, 64]
   nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]
   nn.BatchNorm2d(128),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [128, 32, 32]
   nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]
   nn.BatchNorm2d(256),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [256, 16, 16]
   nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 8, 8]
   nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 4, 4]
  )
  self.fc = nn.Sequential(
   nn.Linear(512 * 4 * 4, 1024),
   nn.ReLU(),
   nn.Linear(1024, 512),
   nn.ReLU(),
   nn.Linear(512, classNum)
  )
 def forward(self, x):
  out = self.cnn(x)
  out = out.view(out.size()[0], -1)
  return self.fc(out)
def train(train_set, train_loader, val_set, val_loader):
 model = Model()
 loss = nn.CrossEntropyLoss() # 因为是分类任务,所以loss function使用 CrossEntropyLoss
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # optimizer 使用 Adam
 num_epoch = 10
 # 开始训练
 for epoch in range(num_epoch):
  epoch_start_time = time.time()
  train_acc = 0.0
  train_loss = 0.0
  val_acc = 0.0
  val_loss = 0.0
  model.train() # train model会开放Dropout和BN
  for i, data in enumerate(train_loader):
   optimizer.zero_grad() # 用 optimizer 將 model 參數的 gradient 歸零
   train_pred = model(data[0]) # 利用 model 的 forward 函数返回预测结果
   batch_loss = loss(train_pred, data[1]) # 计算 loss
   batch_loss.backward() # tensor(item, grad_fn=<NllLossBackward>)
   optimizer.step() # 以 optimizer 用 gradient 更新参数
   train_acc += np.sum(np.argmax(train_pred.data.numpy(), axis=1) == data[1].numpy())
   train_loss += batch_loss.item()
  model.eval()
  with torch.no_grad(): # 不跟踪梯度
   for i, data in enumerate(val_loader):
    # data = [imgData, labelList]
    val_pred = model(data[0])
    batch_loss = loss(val_pred, data[1])
    val_acc += np.sum(np.argmax(val_pred.data.numpy(), axis=1) == data[1].numpy())
    val_loss += batch_loss.item()
   # 打印结果
   print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' % \
     (epoch + 1, num_epoch, time.time() - epoch_start_time, \
     train_acc / train_set.__len__(), train_loss / train_set.__len__(), val_acc / val_set.__len__(),
     val_loss / val_set.__len__()))
if __name__ == '__main__':
 dirPath = '/data/Matt/QC_images/test0916' # 图片文件目录
 createImgIndex(dirPath, 0.2)    # 创建train.txt, val.txt
 root = os.getcwd() + '/data/'
 train_data = MyDataset(txt=root+'train_section1015.csv', transform=transforms.ToTensor())
 val_data = MyDataset(txt=root+'val_section1015.csv', transform=transforms.ToTensor())
 train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True, num_workers = 4)
 val_loader = DataLoader(dataset=val_data, batch_size=6, shuffle=False, num_workers = 4)
 # 开始训练模型
 train(train_data, train_loader, val_data, val_loader)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。

(0)

相关推荐

  • pytorch 数据加载性能对比分析

    传统方式需要10s,dat方式需要0.6s import os import time import torch import random from common.coco_dataset import COCODataset def gen_data(batch_size,data_path,target_path): os.makedirs(target_path,exist_ok=True) dataloader = torch.utils.data.DataLoader(COCODat

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch加载自己的图像数据集实例

    之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码.到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题. 参考文章https://www.jb51.net/article/177613.htm 下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor.最后读取第一张图片并显示. # 数据处理 import os import torch from torch.utils import data fr

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • Pytorch自己加载单通道图片用作数据集训练的实例

    pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100.在torchvision的dataset包里面,用的时候直接调用就行了.具体的调用格式可以去看文档(目前好像只有英文的).网上也有很多源代码. 不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了.pytorch在torchvision.dataset包里面封装过一个函数ImageFolder().这个函数功能很强大,只要你直接将

  • PyTorch加载自己的数据集实例详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力. 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能.为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载. 数据集存放大致有以下两种方式: (1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg root/cat_dog/cat.02.jpg ...

  • pytorch加载语音类自定义数据集的方法教程

    前言 pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法 __len()__ :返回数据集中数据的数量 __getitem()__ :返回支持下标索引方式获取的一个数据 torch.utils.data.DataLoad

  • pytorch从csv加载自定义数据模板的操作

    整理了一套模板,全注释了,这个难点终于克服了 from PIL import Image import pandas as pd import numpy as np import torchvision.transforms as transforms from torch.utils.data import Dataset, DataLoader import os #放文件的路径 dir_path= './97/train/' csv_path='./97/train.csv' class

  • Ext GridPanel加载完数据后进行操作示例代码

    比如load数据之后选定某些行数据. this相当于当前的GridPanel, idxs相当于你要选中的行号 复制代码 代码如下: this.store.on("load",function(store) { this.getSelectionModel().selectRows(idxs); //this.selectedRows = []; },this);

  • pytorch加载自定义网络权重的实现

    在将自定义的网络权重加载到网络中时,报错: AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead. 我们一步一步分析. 模型网络权重保存额代码是:torch.sa

  • ASP.NET仿新浪微博下拉加载更多数据瀑布流效果

    闲来无事,琢磨着写点东西.貌似页面下拉加载数据,瀑布流的效果很火,各个网站都能见到各式各样的展示效果,原理大同小异.于是乎,决定自己写一写这个效果,希望能给比我还菜的菜鸟们一点参考价值. 在开始之前,先把实现的基本原理说一下.当夜幕下拉到底部的时候,js可以判断滚动条的位置,到达底部触发js方法,执行jquery的ajax方法,向后台一般处理程序夜幕ashx文件请求数据源,得到json格式的数据源.然后,遍历json数据源,拼接一个li标签,再填充到页面上去. 首先,我们来做个简单的html页面

  • vue加载自定义的js文件方法

    在做项目中需要自定义弹出框.就自己写了一个. 效果图 遇见的问题 怎么加载自定义的js文件 vue-插件这必须要看.然后就是自己写了. export default{ install(Vue){ var tpl; // 弹出框 Vue.prototype.showAlter = (title,msg) =>{ var alterTpl = Vue.extend({ // 1.创建构造器,定义好提示信息的模板 template: '<div id="bg">' + '&

  • 详解Angular结合zTree异步加载节点数据

    1 前提准备 1.1 新建一个angular4项目 参考://www.jb51.net/article/119668.htm 1.2 去zTree官网下载zTree zTree官网:点击前往 三少使用的版本:点击前往 1.3 参考博客 //www.jb51.net/article/133284.htm 2 编程步骤 从打印出zTree对象可以看出,zTree对象利用init方法来实现zTree结构:init方法接收三个参数 参数1:一个ul标签的DOM节点对象 参数2:基本配置对象 参数3:标题

  • Vue向下滚动加载更多数据scroll案例详解

    vue-infinite-scroll 安装 npm install vue-infinite-scroll --save 尽管官方也推荐了几种载入方式,但"最vue"的方式肯定是在main.js中加入 import infiniteScroll from 'vue-infinite-scroll' Vue.use(infiniteScroll) 实现范例 官方给的代码范例是假设你在根组件写代码,实际上我们肯定是在子组件里写代码,所以代码也需要略作修改,下方只列有用的代码片段: <

  • Vue页面加载完成后如何自动加载自定义函数

    目录 页面加载完成后自动加载自定义函数 created mounted vue之自执行函数 页面加载完成后自动加载自定义函数 created 在模板渲染成html前调用,即通常初始化某些属性值,然后再渲染成视图. methods: {             indexs:function(){                 this.$http.post('{:url("Index/fun")}')                     .then(function(res){

  • 解决Ajax加载JSon数据中文乱码问题

    一.问题描述 使用zTree的异步刷新父级菜单时,服务器返回中文乱码,但项目中使用了SpringMvc,已经对中文乱码处理,为什么还会出现呢? 此处为的异步请求的配置: Java代码 async: { enable: true, url: basePath + '/sysMenu/listSysMenu', autoParam: ["id=parentId"] } SpringMvc中文字符处理: Java代码 <mvc:annotation-driven> <mvc

随机推荐