深度学习入门之Pytorch 数据增强的实现

数据增强

卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 读入一张图片
im = Image.open('./cat.png')
im

随机比例放缩

随机比例缩放主要使用的是 torchvision.transforms.Resize() 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im

随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 100 x 100 的区域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

# 中心裁剪出 100 x 100 的区域
center_im = tfs.CenterCrop(100)(im)
center_im

随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
bright_im

# 对比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
contrast_im

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im

上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
  tfs.Resize(120),
  tfs.RandomHorizontalFlip(),
  tfs.RandomCrop(96),
  tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练

使用数据增强

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

不使用数据增强

# 不使用数据增强
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。

而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注我们其它相关文章!

(0)

相关推荐

  • pytorch cnn 识别手写的字实现自建图片数据

    本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下: # library # standard library import os # third-party library import torch import torch.nn as nn from torch.autograd import Variable from torch.utils.data import Dataset, DataLoader import torchvision impo

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch 把MNIST数据集转换成图片和txt的方法

    本文介绍了pytorch 把MNIST数据集转换成图片和txt的方法,分享给大家,具体如下: 1.下载Mnist 数据集 import os # third-party library import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt # t

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • pytorch 调整某一维度数据顺序的方法

    在pytorch中,Tensor是以引用的形式存在的,故而并不能直接像python交换数据那样 a = torch.Tensor(3,4) a[0],a[1] = a[1],a[0] # 这会导致a的结果为a=(a[1],a[1],a[2]) # 而非预期的(a[1],a[0],a[2]) 这是因为引用赋值导致的,在交换过程,如下所示,当b的值赋值与a的时候,因为tmp指针与a是同一变量的不同名,故而tmp的内容也会变为b. # 交换a,b a,b = b,a # 等价于 tmp = a a =

  • PyTorch读取Cifar数据集并显示图片的实例讲解

    首先了解一下需要的几个类所在的package from torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等. #DataLoader读入的数据类型是PIL.Image

  • 深度学习入门之Pytorch 数据增强的实现

    数据增强 卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法. 2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下. 常用的数据增强方法 常用的数据增强方法如下: 1.对图片进行一定比例缩放 2.对图片进行随机位置的截取 3.对图片进行随机的水平和竖直翻转 4.对图片进行随机角度的旋转 5.对图片进行亮度.对比度和颜色的

  • PyTorch的深度学习入门之PyTorch安装和配置

    前言 深度神经网络是一种目前被广泛使用的工具,可以用于图像识别.分类,物体检测,机器翻译等等.深度学习(DeepLearning)是一种学习神经网络各种参数的方法.因此,我们将要介绍的深度学习,指的是构建神经网络结构,并且运用各种深度学习算法训练网络参数,进而解决各种任务.本文从PyTorch环境配置开始.PyTorch是一种Python接口的深度学习框架,使用灵活,学习方便.还有其他主流的深度学习框架,例如Caffe,TensorFlow,CNTK等等,各有千秋.笔者认为,初期学习还是选择一种

  • python人工智能深度学习入门逻辑回归限制

    目录 1.逻辑回归的限制 2.深度学习的引入 3.深度学习的计算方式 4.神经网络的损失函数 1.逻辑回归的限制 逻辑回归分类的时候,是把线性的函数输入进sigmoid函数进行转换,后进行分类,会在图上画出一条分类的直线,但像下图这种情况,无论怎么画,一条直线都不可能将其完全分开. 但假如我们可以对输入的特征进行一个转换,便有可能完美分类.比如: 创造一个新的特征x1:到(0,0)的距离,另一个x2:到(1,1)的距离.这样可以计算出四个点所对应的新特征,画到坐标系上如以下右图所示.这样转换之后

  • Python-OpenCV深度学习入门示例详解

    目录 0. 前言 1. 计算机视觉中的深度学习简介 1.1 深度学习的特点 1.2 深度学习大爆发 2. 用于图像分类的深度学习简介 3. 用于目标检测的深度学习简介 4. 深度学习框架 keras 介绍与使用 4.1 keras 库简介与安装 4.2 使用 keras 实现线性回归模型 4.3 使用 keras 进行手写数字识别 小结 0. 前言 深度学习已经成为机器学习中最受欢迎和发展最快的领域.自 2012 年深度学习性能超越机器学习等传统方法以来,深度学习架构开始快速应用于包括计算机视觉

  • Python深度学习之使用Pytorch搭建ShuffleNetv2

    一.model.py 1.1 Channel Shuffle def channel_shuffle(x: Tensor, groups: int) -> Tensor: batch_size, num_channels, height, width = x.size() channels_per_group = num_channels // groups # reshape # [batch_size, num_channels, height, width] -> [batch_size

  • PyTorch的深度学习入门教程之构建神经网络

    前言 本文参考PyTorch官网的教程,分为五个基本模块来介绍PyTorch.为了避免文章过长,这五个模块分别在五篇博文中介绍. Part3:使用PyTorch构建一个神经网络 神经网络可以使用touch.nn来构建.nn依赖于autograd来定义模型,并且对其求导.一个nn.Module包含网络的层(layers),同时forward(input)可以返回output. 这是一个简单的前馈网络.它接受输入,然后一层一层向前传播,最后输出一个结果. 训练神经网络的典型步骤如下: (1)  定义

  • 理解深度学习之深度学习简介

    机器学习 在吴恩达老师的课程中,有过对机器学习的定义: ML:<P T E> P即performance,T即Task,E即Experience,机器学习是对一个Task,根据Experience,去提升Performance: 在机器学习中,神经网络的地位越来越重要,实践发现,非线性的激活函数有助于神经网络拟合分布,效果明显优于线性分类器: y=Wx+b 常用激活函数有ReLU,sigmoid,tanh: sigmoid将值映射到(0,1): tanh会将输入映射到(-1,1)区间: #激活

  • Python编程深度学习计算库之numpy

    NumPy是python下的计算库,被非常广泛地应用,尤其是近来的深度学习的推广.在这篇文章中,将会介绍使用numpy进行一些最为基础的计算. NumPy vs SciPy NumPy和SciPy都可以进行运算,主要区别如下 最近比较热门的深度学习,比如在神经网络的算法,多维数组的使用是一个极为重要的场景.如果你熟悉tensorflow中的tensor的概念,你会非常清晰numpy的作用.所以熟悉Numpy可以说是使用python进行深度学习入门的一个基础知识. 安装 liumiaocn:tmp

  • Python深度学习albumentations数据增强库

    数据增强的必要性 深度学习在最近十年得以风靡得益于计算机算力的提高以及数据资源获取的难度下降.一个好的深度模型往往需要大量具有label的数据,使得模型能够很好的学习这种数据的分布.而给数据打标签往往是一件耗时耗力的工作. 拿cv里的经典任务为例,classification需要人准确识别物品类别或者生物种类,object detection需要人工画出bounding box, 确定其坐标,semantic segmentation甚至需要在像素级别进行标签标注.对于一些专业领域的图像标注,依

  • pytorch深度神经网络入门准备自己的图片数据

    目录 正文 一.所有图片放在一个文件夹内 二.不同类别的图片放在不同的文件夹内 正文 图片数据一般有两种情况: 1.所有图片放在一个文件夹内,另外有一个txt文件显示标签. 2.不同类别的图片放在不同的文件夹内,文件夹就是图片的类别. 针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理.下面分别进行说明: 一.所有图片放在一个文件夹内 这里以mnist数据集的10000个te

随机推荐