pytorch sampler对数据进行采样的实现

PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。

下面举例说明。

from dataSet import *
dataset = DogCat('data/dogcat/', transform=transform)

from torch.utils.data import DataLoader
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]

print(weights)

from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
                num_samples=9,\
                replacement=True)
dataloader = DataLoader(dataset,
            batch_size=3,
            sampler=sampler)
for datas, labels in dataloader:
  print(labels.tolist())

输出:

[2, 2, 1, 1, 2, 1, 1, 2]
[1, 1, 0]
[1, 0, 0]
[0, 0, 1]

github 地址:

https://github.com/WebLearning17/CommonTool

以上这篇pytorch sampler对数据进行采样的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch之contiguous的用法

    contiguous tensor变量调用contiguous()函数会使tensor变量在内存中的存储变得连续. contiguous():view只能用在contiguous的variable上.如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy. 一种可能的解释是: 有些tensor并不是占用一整块内存,而是由不同的数据块组成,而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguo

  • 使用pytorch实现可视化中间层的结果

    摘要 一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果 这是原图,随便从网上下载的一张大概224*224大小的图片,如下 网络介绍 我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片 结果如下: 原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片: 下面看看第六层的结果 这层的输出大小是 1

  • pytorch torch.expand和torch.repeat的区别详解

    1.torch.expand 函数返回张量在某一个维度扩展之后的张量,就是将张量广播到新形状.函数对返回的张量不会分配新内存,即在原始张量上返回只读视图,返回的张量内存是不连续的.类似于numpy中的broadcast_to函数的作用.如果希望张量内存连续,可以调用contiguous函数. 例子: import torch x = torch.tensor([1, 2, 3, 4]) xnew = x.expand(2, 4) print(xnew) 输出: tensor([[1, 2, 3,

  • pytorch sampler对数据进行采样的实现

    PyTorch中还单独提供了一个sampler模块,用来对数据进行采样.常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据.默认的是采用SequentialSampler,它会按顺序一个一个进行采样.这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样. 构建WeightedRandomSampler时

  • PyTorch读取Cifar数据集并显示图片的实例讲解

    首先了解一下需要的几个类所在的package from torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等. #DataLoader读入的数据类型是PIL.Image

  • Pytorch 使用 nii数据做输入数据的操作

    使用pix2pix-gan做医学图像合成的时候,如果把nii数据转成png格式会损失很多信息,以为png格式图像的灰度值有256阶,因此直接使用nii的医学图像做输入会更好一点. 但是Pythorch中的Dataloader是不能直接读取nii图像的,因此加一个CreateNiiDataset的类. 先来了解一下pytorch中读取数据的主要途径--Dataset类.在自己构建数据层时都要基于这个类,类似于C++中的虚基类. 自己构建的数据层包含三个部分 class Dataset(object

  • pytorch读取图像数据转成opencv格式实例

    pytorch读取图像数据转成opencv格式方法:先转成numpy通用的格式,再将其转换成opencv格式. pytorch读取的数据使用loaddata这类函数实现.pytorch网络输入图像的格式为(C, H, W),就是(通道数,高,宽)而numpy中图像的格式为(H,W,C). 那就将其通道调换一下.用到函数transpose. 转换方法如下 例如A 的格式为(c,h,w) 那么经过 A = A.transpose(1,2,0) 后就变成了(h,w,c)了 然后用语句 B= cv2.c

  • pytorch 把图片数据转化成tensor的操作

    摘要: 在图像识别当中,一般步骤是先读取图片,然后把图片数据转化成tensor格式,再输送到网络中去.本文将介绍如何把图片转换成tensor. 一.数据转换 把图片转成成torch的tensor数据,一般采用函数:torchvision.transforms.通过一个例子说明,先用opencv读取一张图片,然后在转换:注意一点是:opencv储存图片的格式和torch的储存方式不一样,opencv储存图片格式是(H,W,C),而torch储存的格式是(C,H,W). import torchvi

  • Pytorch数据读取与预处理该如何实现

    在炼丹时,数据的读取与预处理是关键一步.不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的.Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加快以后的代码效率. 根据数据是否一次性读取完,将DEMO分为: 1.串行式读取.也就是一次性读取完所有需要的数据到内存,模型训练时不会再访问外存.通常用在内存足够的情况下使用,速度更快. 2.并行式读取.也就是边训练边读取数据.通常用在内存不够的情况下使用,会占用计算资源,如果分

  • Pytorch上下采样函数--interpolate用法

    最近用到了上采样下采样操作,pytorch中使用interpolate可以很轻松的完成 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None): r""" 根据给定 size 或 scale_factor,上采样或下采样输入数据input. 当前支持 temporal, spatial 和 volumetric 输入数据的上采样,其shape 分别为:3-

  • pytorch中dataloader 的sampler 参数详解

    目录 1. dataloader() 初始化函数 2. shuffle 与sample 之间的关系 3. sample 的定义方法 3.1 sampler 参数的使用 4. batch 生成过程 1. dataloader() 初始化函数 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_mem

  • 解决pytorch load huge dataset(大数据加载)

    问题 最近用pytorch做实验时,遇到加载大量数据的问题.实验数据大小在400Gb,而本身机器的memory只有256Gb,显然无法将数据一次全部load到memory. 解决方法 首先自定义一个MyDataset继承torch.utils.data.Dataset,然后将MyDataset的对象feed in torch.utils.data.DataLoader()即可. MyDataset在__init__中声明一个文件对象,然后在__getitem__中缓慢读取数据,这样就不会一次把所

  • 如何使用PyTorch实现自由的数据读取

    目录 前言 PyTorch数据读入函数介绍 ImageFolder Dataset DataLoader 问题来源 自定义数据读入的举例实现 总结 前言 很多前人曾说过,深度学习好比炼丹,框架就是丹炉,网络结构及算法就是单方,而数据集则是原材料,为了能够炼好丹,首先需要一个使用称手的丹炉,同时也要有好的单方和原材料,最后就需要炼丹师们有着足够的经验和技巧掌握火候和时机,这样方能炼出绝世好丹. 对于刚刚进入炼丹行业的炼丹师,网上都有一些前人总结的炼丹技巧,同时也有很多炼丹师的心路历程以及丹师对整个

随机推荐