PyTorch详解经典网络ResNet实现流程

目录
  • 简述
  • 残差结构
  • 18-layer 实现
  • 在数据集训练

简述

GoogleNet 和 VGG 等网络证明了,更深度的网络可以抽象出表达能力更强的特征,进而获得更强的分类能力。在深度网络中,随之网络深度的增加,每层输出的特征图分辨率主要是高和宽越来越小,而深度逐渐增加。

深度的增加理论上能够提升网络的表达能力,但是对于优化来说就会产生梯度消失的问题。在深度网络中,反向传播时,梯度从输出端向数据端逐层传播,传播过程中,梯度的累乘使得近数据段接近0值,使得网络的训练失效。

为了解决梯度消失问题,可以在网络中加入BatchNorm,激活函数换成ReLU,一定程度缓解了梯度消失问题。

深度增加的另一个问题就是网络的退化(Degradation of deep network)问题。即,在现有网络的基础上,增加网络的深度,理论上,只有训练到最佳情况,新网络的性能应该不会低于浅层的网络。因为,只要将新增加的层学习成恒等映射(identity mapping)就可以。换句话说,浅网络的解空间是深的网络的解空间的子集。但是由于Degradation问题,更深的网络并不一定好于浅层网络。

Residual模块的想法就是认为的让网络实现这种恒等映射。如图,残差结构在两层卷积的基础上,并行添加了一个分支,将输入直接加到最后的ReLU激活函数之前,如果两层卷积改变大量输入的分辨率和通道数,为了能够相加,可以在添加的分支上使用1x1卷积来匹配尺寸。

残差结构

ResNet网络有两种残差块,一种是两个3x3卷积,一种是1x1,3x3,1x1三个卷积网络串联成残差模块。

PyTorch 实现:

class Residual_1(nn.Module):
    r"""
    18-layer, 34-layer 残差块
    1. 使用了类似VGG的3×3卷积层设计;
    2. 首先使用两个相同输出通道数的3×3卷积层,后接一个批量规范化和ReLU激活函数;
    3. 加入跨过卷积层的通路,加到最后的ReLU激活函数前;
    4. 如果要匹配卷积后的输出的尺寸和通道数,可以在加入的跨通路上使用1×1卷积;
    """
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        r"""
        parameters:
            input_channels: 输入的通道上数
            num_channels: 输出的通道数
            use_1x1conv: 是否需要使用1x1卷积控制尺寸
            stride: 第一个卷积的步长
        """
        super().__init__()
        # 3×3卷积,strides控制分辨率是否缩小
        self.conv1 = nn.Conv2d(input_channels,
                               num_channels,
                               kernel_size=3,
                               padding=1,
                               stride=strides)
        # 3×3卷积,不改变分辨率
        self.conv2 = nn.Conv2d(num_channels,
                               num_channels,
                               kernel_size=3,
                               padding=1)
        # 使用 1x1 卷积变换输入的分辨率和通道
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels,
                                   num_channels,
                                   kernel_size=1,
                                   stride=strides)
        else:
            self.conv3 = None
        # 批量规范化层
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        # print(X.shape)
        Y += X
        return F.relu(Y)
class Residual_2(nn.Module):
    r"""
    50-layer, 101-layer, 152-layer 残差块
    1. 首先使用1x1卷积,ReLU激活函数;
    2. 然后用3×3卷积层,在接一个批量规范化,ReLU激活函数;
    3. 再接1x1卷积层;
    4. 加入跨过卷积层的通路,加到最后的ReLU激活函数前;
    5. 如果要匹配卷积后的输出的尺寸和通道数,可以在加入的跨通路上使用1×1卷积;
    """
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        r"""
        parameters:
            input_channels: 输入的通道上数
            num_channels: 输出的通道数
            use_1x1conv: 是否需要使用1x1卷积控制尺寸
            stride: 第一个卷积的步长
        """
        super().__init__()
        # 1×1卷积,strides控制分辨率是否缩小
        self.conv1 = nn.Conv2d(input_channels,
                               num_channels,
                               kernel_size=1,
                               padding=1,
                               stride=strides)
        # 3×3卷积,不改变分辨率
        self.conv2 = nn.Conv2d(num_channels,
                               num_channels,
                               kernel_size=3,
                               padding=1)
        # 1×1卷积,strides控制分辨率是否缩小
        self.conv3 = nn.Conv2d(input_channels,
                               num_channels,
                               kernel_size=1,
                               padding=1)
        # 使用 1x1 卷积变换输入的分辨率和通道
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels,
                                   num_channels,
                                   kernel_size=1,
                                   stride=strides)
        else:
            self.conv3 = None
        # 批量规范化层
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = F.relu(self.bn2(self.conv2(Y)))
        Y = self.conv3(Y)
        if self.conv3:
            X = self.conv3(X)
        # print(X.shape)
        Y += X
        return F.relu(Y)

ResNet有不同的网络层数,比较常用的是50-layer,101-layer,152-layer。他们都是由上述的残差模块堆叠在一起实现的。

以18-layer为例,层数是指:首先,conv_1 的一层7x7卷积,然后conv_2~conv_5四个模块,每个模块两个残差块,每个残差块有两层的3x3卷积组成,共4×2×2=16层,最后是一层分类层(fc),加总一起共1+16+1=18层。

18-layer 实现

首先定义由残差结构组成的模块:

# ResNet模块
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    r"""残差块组成的模块"""
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual_1(input_channels,
                                num_channels,
                                use_1x1conv=True,
                                strides=2))
        else:
            blk.append(Residual_1(num_channels, num_channels))
    return blk

定义18-layer的最开始的层:

# ResNet的前两层:
#    1. 输出通道数64, 步幅为2的7x7卷积层
#    2. 步幅为2的3x3最大汇聚层
conv_1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

定义残差组模块:

# ResNet模块
conv_2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
conv_3 = nn.Sequential(*resnet_block(64, 128, 2))
conv_4 = nn.Sequential(*resnet_block(128, 256, 2))
conv_5 = nn.Sequential(*resnet_block(256, 512, 2))

ResNet 18-layer模型:

net = nn.Sequential(conv_1, conv_2, conv_3, conv_4, conv_5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(512, 10))
# 观察模型各层的输出尺寸
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

输出:

Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 128, 28, 28])
Sequential output shape:     torch.Size([1, 256, 14, 14])
Sequential output shape:     torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:     torch.Size([1, 512, 1, 1])
Flatten output shape:     torch.Size([1, 512])
Linear output shape:     torch.Size([1, 10])

在数据集训练

def load_datasets_Cifar10(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        transform = trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)
    test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)
    print("Cifar10 下载完成...")
    return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
            torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets_FashionMNIST(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        transform = trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
    print("FashionMNIST 下载完成...")
    return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
            torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets(dataset, batch_size, resize):
    if dataset == "Cifar10":
        return load_datasets_Cifar10(batch_size, resize=resize)
    else:
        return load_datasets_FashionMNIST(batch_size, resize=resize)
train_iter, test_iter = load_datasets("", 128, 224) # Cifar10

到此这篇关于PyTorch详解经典网络ResNet实现流程的文章就介绍到这了,更多相关PyTorch ResNet内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch修改ResNet模型全连接层进行直接训练实例

    之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把 最后一层的输出改一下,不需要加载前面层的权重,方法如下: model = torchvision.models.resnet18(pretrained=False) num_fc_ftr = model.fc.in_features model.fc = torch.nn.Linear(num_fc_ftr, 224) model =

  • 聊聊基于pytorch实现Resnet对本地数据集的训练问题

    目录 1.dataset.py(先看代码的总体流程再看介绍) 2.network.py 3.train.py 4.结果与总结 本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py.network.py.dataset.py以及train.py文件,功能是对本地的数据集进行分类.本文介绍逻辑是总分形式,即首先对总流程进行一个概括,然后分别介绍每个流程中的实现过程(代码+流程图+文字的介绍). 对于整个项目的流程首

  • pytorch实现用Resnet提取特征并保存为txt文件的方法

    接触pytorch一天,发现pytorch上手的确比TensorFlow更快.可以更方便地实现用预训练的网络提特征. 以下是提取一张jpg图像的特征的程序: # -*- coding: utf-8 -*- import os.path import torch import torch.nn as nn from torchvision import models, transforms from torch.autograd import Variable import numpy as np

  • pytorch教程resnet.py的实现文件源码分析

    目录 调用pytorch内置的模型的方法 解读模型源码Resnet.py 包含的库文件 该库定义了6种Resnet的网络结构 每种网络都有训练好的可以直接用的.pth参数文件 Resnet中大多使用3*3的卷积定义如下 如何定义不同大小的Resnet网络 定义Resnet18 定义Resnet34 Resnet类 网络的forward过程 残差Block连接是如何实现的 调用pytorch内置的模型的方法 import torchvision model = torchvision.models

  • pytorch实现ResNet结构的实例代码

    1.ResNet的创新 现在重新稍微系统的介绍一下ResNet网络结构. ResNet结构首先通过一个卷积层然后有一个池化层,然后通过一系列的残差结构,最后再通过一个平均池化下采样操作,以及一个全连接层的得到了一个输出.ResNet网络可以达到很深的层数的原因就是不断的堆叠残差结构而来的. 1)亮点 网络中的亮点 : 超深的网络结构( 突破1000 层) 提出residual 模块 使用Batch Normalization 加速训练( 丢弃dropout) 但是,一般来说,并不是一直的加深神经

  • PyTorch实现ResNet50、ResNet101和ResNet152示例

    PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks import torch import torch.nn as nn import torchvision import numpy as np print("PyTorch Version: ",torch.__version__) print("Torchvision Version: ",torchvision.__version__) _

  • 人工智能学习pyTorch的ResNet残差模块示例详解

    目录 1.定义ResNet残差模块 ①各层的定义 ②前向传播 2.ResNet18的实现 ①各层的定义 ②前向传播 3.测试ResNet18 1.定义ResNet残差模块 一个block中,有两个卷积层,之后的输出还要和输入进行相加.因此一个block的前向流程如下: 输入x→卷积层→数据标准化→ReLU→卷积层→数据标准化→数据和x相加→ReLU→输出out 中间加上了数据的标准化(通过nn.BatchNorm2d实现),可以使得效果更好一些. ①各层的定义 ②前向传播 在前向传播中输入x,过

  • PyTorch详解经典网络ResNet实现流程

    目录 简述 残差结构 18-layer 实现 在数据集训练 简述 GoogleNet 和 VGG 等网络证明了,更深度的网络可以抽象出表达能力更强的特征,进而获得更强的分类能力.在深度网络中,随之网络深度的增加,每层输出的特征图分辨率主要是高和宽越来越小,而深度逐渐增加. 深度的增加理论上能够提升网络的表达能力,但是对于优化来说就会产生梯度消失的问题.在深度网络中,反向传播时,梯度从输出端向数据端逐层传播,传播过程中,梯度的累乘使得近数据段接近0值,使得网络的训练失效. 为了解决梯度消失问题,可

  • PyTorch详解经典网络种含并行连结的网络GoogLeNet实现流程

    目录 1. Inception块 2. 构造 GoogLeNet 网络 3. FashionMNIST训练测试 含并行连结的网络 GoogLeNet 在GoogleNet出现值前,流行的网络结构使用的卷积核从1×1到11×11,卷积核的选择并没有太多的原因.GoogLeNet的提出,说明有时候使用多个不同大小的卷积核组合是有利的. import torch from torch import nn from torch.nn import functional as F 1. Inception

  • 详解Java网络编程

    一.网络编程 1.1.概述 1.计算机网络是通过传输介质.通信设施和网络通信协议,把分散在不同地点的计算机设备互连起来,实现资源共享和数据传输的系统.网络编程就就是编写程序使联网的两个(或多个)设备(例如计算机)之间进行数据传输.Java语言对网络编程提供了良好的支持,通过其提供的接口我们可以很方便地进行网络编程. 2.Java是 Internet 上的语言,它从语言级上提供了对网络应用程 序的支持,程序员能够很容易开发常见的网络应用程序. 3.Java提供的网络类库,可以实现无痛的网络连接,联

  • 详解Git合并分支的流程步骤

    正常合并分支dev到master流程: (合并到其他分支类似哈) 1.要合并的dev分支先更新提交所有文件 注意: 如果不需要提交的本地化修改文件的话,最好不要提交上去.临时备份然后删掉或者撤回. 进入项目根目录,然后执行: git add . git commit -m '提交所有dev分支的文件' git push -u origin dev 2.切换到master分支 git checkout master 3.更新master代码到最新 git pull origin master 4.

  • 详解Android Activity的启动流程

    前言 activity启动的流程分为两部分:一是在activity中通过startActivity(Intent intent)方法启动一个Activity:二是我们在桌面通过点击应用图标启动一个App然后显示Activity:第二种方式相较于第一种方式更加全面,所以本文会以第二种流程来分析. 简要 我们手机的桌面是一个叫做Launcher的Activity,它罗列了手机中的应用图标,图标中包含安装apk时解析的应用默认启动页等信息.在点击应用图标时,即将要启动的App和Launcher.AMS

  • 详解Spring容器的使用流程

    前言 Spring容器的API有 BeanFactory 和 ApplicationContext 两大类,他们都是顶级接口.其中ApplicationContext 是 BeanFactory 的子接口.对于两者的说明请参考面试题讲解Spring容器部分.我们主要使用 ApplicationContext 应用上下文接口. 一.主要流程 二.开发步骤 2.1 准备Maven项目及环境 首先创建一个Maven项目,名称为 spring-study ,以下是项目的maven配置文件 pom.xml

  • 详解JavaScript引擎V8执行流程

    目录 一.V8来源 二.V8的服务对象 三.V8的早期架构 四.V8早期架构的缺陷 五.V8的现有架构 六.V8的词法分析和语法分析 七.V8 AST抽象语法树 八.字节码 九.Turbofan 一.V8来源 V8的名字来源于汽车的"V型8缸发动机"(V8发动机).V8发动机主要是美国发展起来,因为马力十足而广为人知.V8引擎的命名是Google向用户展示它是一款强力并且高速的JavaScript引擎. V8未诞生之前,早期主流的JavaScript引擎是JavaScriptCore引

  • 详解python网络进程

    目录 一.多任务编程 二.进程 三.os.fork创建进程 3.1.进程ID和退出函数 四.孤儿和僵尸 4.1.孤儿进程 4.2.僵尸进程 4.3.如何避免僵尸进程的产生 五.Multiprocessing创建进程 5.1.multiprocessing进程属性 六.进程池 七.进程间通信(IPC) 7.1.管道通信(Pipe) 7.2.消息队列 7.3.共享内存 7.4.信号量(信号灯集) 一.多任务编程 意义:充分利用计算机的资源提高程序的运行效率 定义:通过应用程序利用计算机多个核心,达到

  • springboot与vue详解实现短信发送流程

    目录 一.前期工作 1.开启邮箱服务 2.导入依赖 3.配置application.yaml文件 二.实现流程 1.导入数据库 2.后端实现 编写实体类 编写工具类ResultVo 编写dao层接口 配置dao层接口的数据库操作 编写service层接口 编写service层的实现方法 实现controller层 Test代码 前端页面的实现 运行截图+sql图 总结 一.前期工作 1.开启邮箱服务 开启邮箱的POP3/SMTP服务(这里以qq邮箱为例,网易等都是一样的) 2.导入依赖 在spr

  • Spring Boot详解创建和运行基础流程

    目录 1. 初始 Spring Boot 1.1 什么是Spring Boot 1.2 Spring Boot 的优点 2. 创建 Spring Boot 项目(Idea) 2.1 首先安装 Spring Assistant 插件 2.2 重启Idea-New Project ① 点击 Spring Assistant 直接Next就可以了 ② Next 之后的页面介绍 ③ 引入依赖, 选择Spring Boot的版本 ④ 选择项目名称和保存路径 ⑤ Spring Boot 项目创建完成 3.

随机推荐