PyTorch详解经典网络种含并行连结的网络GoogLeNet实现流程

目录
  • 1. Inception块
  • 2. 构造 GoogLeNet 网络
  • 3. FashionMNIST训练测试

含并行连结的网络 GoogLeNet

在GoogleNet出现值前,流行的网络结构使用的卷积核从1×1到11×11,卷积核的选择并没有太多的原因。GoogLeNet的提出,说明有时候使用多个不同大小的卷积核组合是有利的。

import torch
from torch import nn
from torch.nn import functional as F

1. Inception块

Inception块是 GoogLeNet 的基本组成单元。Inception 块由四条并行的路径组成,每个路径使用不同大小的卷积核:

路径1:使用 1×1 卷积层;

路径2:先对输出执行 1×1 卷积层,来减少通道数,降低模型复杂性,然后接 3×3 卷积层;

路径3:先对输出执行 1×1 卷积层,然后接 5×5 卷积层;

路径4:使用 3×3 最大汇聚层,然后使用 1×1 卷积层;

在各自路径中使用合适的 padding ,使得各个路径的输出拥有相同的高和宽,然后将每条路径的输出在通道维度上做连结,作为 Inception 块的最终输出.

class Inception(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Inception, self).__init__()
        # 路径1
        c1, c2, c3, c4 = out_channels
        self.route1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # 路径2
        self.route2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.route2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # 路径3
        self.route3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.route3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # 路径4
        self.route4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.route4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)
    def forward(self, x):
        x1 = F.relu(self.route1_1(x))
        x2 = F.relu(self.route2_2(F.relu(self.route2_1(x))))
        x3 = F.relu(self.route3_2(F.relu(self.route3_1(x))))
        x4 = F.relu(self.route4_2(self.route4_1(x)))
        return torch.cat((x1, x2, x3, x4), dim=1) 

2. 构造 GoogLeNet 网络

顺序定义 GoogLeNet 的模块。

第一个模块,顺序使用三个卷积层。

# 模型的第一个模块
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    nn.Conv2d(64, 64, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(64, 192, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                   )

第二个模块,使用两个Inception模块。

# Inception组成的第二个模块
b2 = nn.Sequential(
    Inception(192, (64, (96, 128), (16, 32), 32)),
    Inception(256, (128, (128, 192), (32, 96), 64)),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                    )

第三个模块,串联五个Inception模块。

# Inception组成的第三个模块
b3 = nn.Sequential(
    Inception(480, (192, (96, 208), (16, 48), 64)),
    Inception(512, (160, (112, 224), (24, 64), 64)),
    Inception(512, (128, (128, 256), (24, 64), 64)),
    Inception(512, (112, (144, 288), (32, 64), 64)),
    Inception(528, (256, (160, 320), (32, 128), 128)),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                    )

第四个模块,传来两个Inception模块。

GoogLeNet使用 avg pooling layer 代替了 fully-connected layer。一方面降低了维度,另一方面也可以视为对低层特征的组合。

# Inception组成的第四个模块
b4 = nn.Sequential(
    Inception(832, (256, (160, 320), (32, 128), 128)),
    Inception(832, (384, (192, 384), (48, 128), 128)),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten()
                    )
net = nn.Sequential(b1, b2, b3, b4, nn.Linear(1024, 10))
x = torch.randn(1, 1, 96, 96)
for layer in net:
    x = layer(x)
    print(layer.__class__.__name__, "output shape: ", x.shape)

输出:

Sequential output shape:  torch.Size([1, 192, 28, 28])
Sequential output shape:  torch.Size([1, 480, 14, 14])
Sequential output shape:  torch.Size([1, 832, 7, 7])
Sequential output shape:  torch.Size([1, 1024])
Linear output shape:  torch.Size([1, 10])

3. FashionMNIST训练测试

def load_datasets_Cifar10(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        transform = trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)
    test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)
    print("Cifar10 下载完成...")
    return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
            torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets_FashionMNIST(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        transform = trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
    print("FashionMNIST 下载完成...")
    return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
            torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets(dataset, batch_size, resize):
    if dataset == "Cifar10":
        return load_datasets_Cifar10(batch_size, resize=resize)
    else:
        return load_datasets_FashionMNIST(batch_size, resize=resize)
train_iter, test_iter = load_datasets("", 128, 96) # Cifar10

训练结果:

到此这篇关于PyTorch详解经典网络种含并行连结的网络GoogLeNet实现流程的文章就介绍到这了,更多相关PyTorch GoogLeNet内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch实现GoogLeNet的方法

    GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本.GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层. 最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图.左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量

  • PyTorch详解经典网络种含并行连结的网络GoogLeNet实现流程

    目录 1. Inception块 2. 构造 GoogLeNet 网络 3. FashionMNIST训练测试 含并行连结的网络 GoogLeNet 在GoogleNet出现值前,流行的网络结构使用的卷积核从1×1到11×11,卷积核的选择并没有太多的原因.GoogLeNet的提出,说明有时候使用多个不同大小的卷积核组合是有利的. import torch from torch import nn from torch.nn import functional as F 1. Inception

  • PyTorch详解经典网络ResNet实现流程

    目录 简述 残差结构 18-layer 实现 在数据集训练 简述 GoogleNet 和 VGG 等网络证明了,更深度的网络可以抽象出表达能力更强的特征,进而获得更强的分类能力.在深度网络中,随之网络深度的增加,每层输出的特征图分辨率主要是高和宽越来越小,而深度逐渐增加. 深度的增加理论上能够提升网络的表达能力,但是对于优化来说就会产生梯度消失的问题.在深度网络中,反向传播时,梯度从输出端向数据端逐层传播,传播过程中,梯度的累乘使得近数据段接近0值,使得网络的训练失效. 为了解决梯度消失问题,可

  • 一文详解Dart如何实现多任务并行

    目录 Isolate(隔离区域) async/await Stream Compute Function Isolate(隔离区域) Dart 是一种支持多任务并行的编程语言,它提供了多种机制来实现并发和并行.下面是 Dart 实现多任务并行的几种方式: Dart 中的 Isolate 是一种轻量级的并发机制,类似于线程.每个隔离区域都是独立的内存空间,每个隔离区域都有自己的内存空间和执行线程,因此不同的隔离区域之间可以独立地执行代码,每个隔离区都在自己的核心上运行,不会阻塞其他 Isolate

  • spring、mybatis 配置方式详解(常用两种方式)

    在之前的文章中总结了三种方式,但是有两种是注解sql的,这种方式比较混乱所以大家不怎么使用,下面总结一下常用的两种总结方式: 一. 动态代理实现 不用写dao的实现类 这种方式比较简单,不用实现dao层,只需要定义接口就可以了,这里只是为了记录配置文件所以程序写的很简单: 1.整体结构图: 2.三个配置文件以及一个映射文件 (1).程序入口以及前端控制器配置 web.xml <?xml version="1.0" encoding="UTF-8"?> &

  • 详解template标签用法(含vue中的用法总结)

    一.html5中的template标签 html中的template标签中的内容在页面中不会显示.但是在后台查看页面DOM结构存在template标签.这是因为template标签天生不可见,它设置了display:none;属性. <!--当前页面只显示"我是自定义表现abc"这个内容,不显示"我是template",这是因为template标签天生不可见--> <template><div>我是template</div

  • 详解Golang五种原子性操作的用法

    目录 Go 语言提供了哪些原子操作 互斥锁跟原子操作的区别 比较并交换 atomic.Value保证任意值的读写安全 总结 本文我们详细聊一下Go语言的原子操作的用法,啥是原子操作呢?顾名思义,原子操作就是具备原子性的操作... 是不是感觉说了跟没说一样,原子性的解释如下: 一个或者多个操作在 CPU 执行的过程中不被中断的特性,称为原子性(atomicity) .这些操作对外表现成一个不可分割的整体,他们要么都执行,要么都不执行,外界不会看到他们只执行到一半的状态. CPU执行一系列操作时不可

  • 详解mybatis三种分页方式

    目录 前言 一.Limit分页 二.RowBounds分页(不推荐使用) 三.Mybatis_PageHelper分页插件 总结: 前言 分页是我们在开发中绕不过去的一个坎!当你的数据量大了的时候,一次性将所有数据查出来不现实,所以我们一般都是分页查询的,减轻服务端的压力,提升了速度和效率!也减轻了前端渲染的压力! 注意:由于 java 允许的最大整数为 2147483647,所以 limit 能使用的最大整数也是 2147483647,一次性取出大量数据可能引起内存溢出,所以在大数据查询场合慎

  • 详解SQL四种语言:DDL DML DCL TCL

    看到很多人讨论SQL还分为四种类型,在这里知识普及一下,并总结下他们的区别吧. 1. DDL – Data Definition Language 数据库定义语言:定义数据库的结构. 其主要命令有CREATE,ALTER,DROP等,下面用例子详解.该语言不需要commit,因此慎重. CREATE – to create objects in the database   在数据库创建对象 例: CREATE DATABASE test; // 创建一个名为test的数据库 ALTER – a

  • 详解java 三种调用机制(同步、回调、异步)

    1:同步调用:一种阻塞式调用,调用方要等待对方执行完毕才返回,它是一种单向调用 2:回调:一种双向调用模式,也就是说,被调用方在接口被调用时也会调用对方的接口: 3:异步调用:一种类似消息或事件的机制,不过它的调用方向刚好相反,接口的服务在收到某种讯息或发生某种事件时,会主动通知客户方(即调用客户方的接口 具体说来:就是A类中调用B类中的某个方法C,然后B类中反过来调用A类中的方法D,D这个方法就叫回调方法, 实例1:使用java中Timer来在给定时间间隔发送通知,每隔十秒打印一次数据 Tim

  • 把JSON数据格式转换为Python的类对象方法详解(两种方法)

    JOSN字符串转换为自定义类实例对象 有时候我们有这种需求就是把一个JSON字符串转换为一个具体的Python类的实例,比如你接收到这样一个JSON字符串如下: {"Name": "Tom", "Sex": "Male", "BloodType": "A", "Hobbies": ["篮球", "足球"]} 我需要把这个转换为具

随机推荐