pytorch 图像中的数据预处理和批标准化实例

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')

import torch

def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x

train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )

    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x

net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1

train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • pytorch 实现将自己的图片数据处理成可以训练的图片类型

    为了使用自己的图像数据,需要仿照pytorch数据输入创建新的类,其中数据格式为numpy.ndarray. 将自己的图片保存到numpy.ndarray中,然后创建类 from torch.utils.data import Dataset import numpy as np class Dataset(Dataset): def __init__(self, path_img, path_target, transforms=None): self.train = path_img sel

  • 关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu18.04 显卡:GTX1080ti python版本:2.7(3.7) 网络架构 具有4层的CNN具有以下架构. 输入层:784个节点(MNIST图像大小) 第一卷积层:5x5x32 第一个最大池层 第二卷积层:5x5x64 第二个最大池层 第三个完全连接层:1024个节点 输出层:10个节点(M

  • pytorch中的自定义数据处理详解

    pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数

  • pytorch 图像中的数据预处理和批标准化实例

    目前数据预处理最常见的方法就是中心化和标准化. 中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征. 标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间 批标准化:BN 在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好.但是对于很深的网路结构,网路的非线性层会使得输出的结

  • 详解如何使用Python隐藏图像中的数据

    目录 编码 例子 解码 程序执行 局限性 参考 隐写术是在任何文件中隐藏秘密数据的艺术. 秘密数据可以是任何格式的数据,如文本甚至文件.简而言之,隐写术的主要目的是隐藏任何文件(通常是图像.音频或视频)中的预期信息,而不实际改变文件的外观,即文件外观看起来和以前一样. 在这篇文章中,我们将重点学习基于图像的隐写术,即在图像中隐藏秘密数据. 但在深入研究之前,让我们先看看图像由什么组成: 1.像素是图像的组成部分. 2.每个像素包含三个值:(红色.绿色.蓝色)也称为 RGB 值. 3.每个 RGB

  • Android 中TeaPickerView数据级联选择器功能的实例代码

    Github地址 YangsBryant/TeaPickerView (Github排版比较好,建议进入这里查看详情,如果觉得好,点个star吧!) 引入module allprojects { repositories { google() jcenter() maven { url 'https://www.jitpack.io' } } } implementation 'com.github.YangsBryant:TeaPickerView:1.0.2' 主要代码 public cla

  • Spark中的数据读取保存和累加器实例详解

    目录 数据读取与保存 Text文件 Sequence文件 Object对象文件 累加器 累加器概念 系统累加器 数据读取与保存 Text文件 对于 Text文件的读取和保存 ,其语法和实现是最简单的,因此我只是简单叙述一下这部分相关知识点,大家可以结合demo具体分析记忆. 1)基本语法 (1)数据读取:textFile(String) (2)数据保存:saveAsTextFile(String) 2)实现代码demo如下: object Operate_Text { def main(args

  • Android App中各种数据保存方式的使用实例总结

    少量数据保存之SharedPreferences接口实例 SharedPreferences数据保存主要是通过键值的方式存储在xml文件中 xml文件在data/此程序的包名/XX.xml. 格式: <?xml version='1.0' encoding='utf-8' standalone='yes' ?> <map> <int name="count" value="3" /> <string name="t

  • pytorch 图像预处理之减去均值,除以方差的实例

    如下所示: #coding=gbk ''' GPU上面的环境变化太复杂,这里我直接给出在笔记本CPU上面的运行时间结果 由于方式3需要将tensor转换到GPU上面,这一过程很消耗时间,大概需要十秒,故而果断抛弃这样的做法 img (168, 300, 3) sub div in numpy,time 0.0110 sub div in torch.tensor,time 0.0070 sub div in torch.tensor with torchvision.transforms,tim

  • pytorch数据预处理错误的解决

    出错: Traceback (most recent call last): File "train.py", line 305, in <module> train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler) File "train.py", line 145, in train_model for inputs, age_labels, gender_labels in

  • pytorch读取图像数据转成opencv格式实例

    pytorch读取图像数据转成opencv格式方法:先转成numpy通用的格式,再将其转换成opencv格式. pytorch读取的数据使用loaddata这类函数实现.pytorch网络输入图像的格式为(C, H, W),就是(通道数,高,宽)而numpy中图像的格式为(H,W,C). 那就将其通道调换一下.用到函数transpose. 转换方法如下 例如A 的格式为(c,h,w) 那么经过 A = A.transpose(1,2,0) 后就变成了(h,w,c)了 然后用语句 B= cv2.c

  • OpenCV提取图像中圆线上的数据具体流程

    目录 需求说明 具体流程 功能函数 C++测试代码 测试效果 总结 需求说明 在对图像进行处理时,经常会有这类需求:客户想要提取出图像中某条直线.圆线或者ROI区域内的感兴趣数据,进行重点关注.该需求在图像检测领域尤其常见.ROI区域一般搭配Rect即可完成提取,直线和圆线数据的提取没有现成的函数,需要自行实现. 直线的提取见: OpenCV获取图像中直线上的数据具体流程 而圆线的提取则是本文要将的内容,对圆线而言,将线上某点作为起点,沿顺时针或逆时针方向依次提取感兴趣数据,可放置在容器中.那么

随机推荐