PyTorch如何创建自己的数据集
目录
- PyTorch创建自己的数据集
- pytorch常用数据集的使用
PyTorch创建自己的数据集
图片文件在同一的文件夹下
思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__方法,示例代码如下:
class ImageFolder(Dataset): def __init__(self, folder_path): self.files = sorted(glob.glob('%s/*.*' % folder_path)) def __getitem__(self, index): path = self.files[index % len(self.files)] img = np.array(Image.open(path)) h, w, c = img.shape pad = ((40, 40), (4, 4), (0, 0)) # img = np.pad(img, pad, 'constant', constant_values=0) / 255 img = np.pad(img, pad, mode='edge') / 255.0 img = torch.from_numpy(img).float() patches = np.reshape(img, (3, 10, 128, 11, 128)) patches = np.transpose(patches, (0, 1, 3, 2, 4)) return img, patches, path def __len__(self): return len(self.files)
图片文件在不同的文件夹下
比如我们有数据如下:
─── data
├── train
│ ├── 0.jpg
│ └── 1.jpg
├── test
│ ├── 0.jpg
│ └── 1.jpg
└── val
├── 1.jpg
└── 2.jpg
此时我们只需要将以上代码稍作修改即可,修改的代码如下:
self.files = sorted(glob.glob('%s/**/*.*' % folder_path, recursive=True))
其他代码不变。
pytorch常用数据集的使用
对于pytorch数据集的使用,示例代码如下:
from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from torchvision import transforms import torchvision import ssl ssl._create_default_https_context = ssl._create_unverified_context dataset_transform = Compose([transforms.ToTensor()]) # 关于官方数据集的使用还是关键要看pytorch的官方文档 train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True) test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=dataset_transform,download=True) # 查看测试数据集中的第一个数据 # print(test_set[0]) # 查看测试数据集中的分类情况 # print(test_set.classes) # # 取出第一个数据中的图片(img)和分类结果(target) # img,target = test_set[0] # 查看图片数据的类型 # print(img) # print(target) # 输出类别 # print(test_set.classes[target]) # 查看图片 # img.show() # 使用tensorboard显示tensor数据类型的图片 writer = SummaryWriter("logs") for i in range(10): # 取出数据中的图片(img)和分类结果(target) img,target = test_set[i] writer.add_image("test_set",img,i) writer.close()
上述代码运行结果在tensorboard可视化:
代码
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
常用参数讲解
root
:根目录,存放数据集的位置train
:若为True,则划分为训练数据集,若为False,则划分为测试数据集transform
:指定输入数据集处理方式download
:若为True,则会将数据集下载到root指定的目录下,否则不会下载
官方文档对参数的解释:
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
注意:
- 关于官方数据集的使用还是关键要看pytorch的官方文档
- 下载数据集的细节之处:知道下载链接(下载链接可以在源码中查看)之后可以不用使用代码下载了,使用迅雷来下载可能会更快。
- 要学会使用Pycharm中的ctrl+p和ctrl+alt这两个快捷键
- pytorch官网
- pytorch官方数据集(下载数据集方法)
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。