简单易懂Pytorch实战实例VGG深度网络

简单易懂Pytorch实战实例VGG深度网络

模型VGG,数据集cifar。对照这份代码走一遍,大概就知道整个pytorch的运行机制。
来源
定义模型:

'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable

cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# 模型需继承nn.Module
class VGG(nn.Module):
# 初始化参数:
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

# 模型计算时的前向过程,也就是按照这个过程进行计算
    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())

定义训练过程:

'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar
from torch.autograd import Variable

# 获取参数
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# 获取数据集,并先进行预处理
print('==> Preparing data..')
# 图像预处理和增强
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 继续训练模型或新建一个模型
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = VGG('VGG16')
    # net = ResNet18()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = MobileNetV2()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()

# 如果GPU可用,使用GPU
if use_cuda:
    # move param and buffer to GPU
    net.cuda()
    # parallel use GPU
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()-1))
    # speed up slightly
    cudnn.benchmark = True

# 定义度量和优化
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# 训练阶段
def train(epoch):
    print('\nEpoch: %d' % epoch)
    # switch to train mode
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # batch 数据
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # 将数据移到GPU上
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # 先将optimizer梯度先置为0
        optimizer.zero_grad()
        # Variable表示该变量属于计算图的一部分,此处是图计算的开始处。图的leaf variable
        inputs, targets = Variable(inputs), Variable(targets)
        # 模型输出
        outputs = net(inputs)
        # 计算loss,图的终点处
        loss = criterion(outputs, targets)
        # 反向传播,计算梯度
        loss.backward()
        # 更新参数
        optimizer.step()
        # 注意如果你想统计loss,切勿直接使用loss相加,而是使用loss.data[0]。因为loss是计算图的一部分,如果你直接加loss,代表total loss同样属于模型一部分,那么图就越来越大
        train_loss += loss.data[0]
        # 数据统计
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

# 测试阶段
def test(epoch):
    global best_acc
    # 先切到测试模型
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph.
        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

# Save checkpoint.
    # 保存模型
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc

# 运行模型
for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)
    # 清除部分无用变量
    torch.cuda.empty_cache()

运行:

新模型:
python main.py --lr=0.01
旧模型继续训练:
python main.py --resume --lr=0.01

一些utility:

'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init

def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias:
                init.constant(m.bias, 0)

_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

# Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

(0)

相关推荐

  • 简单易懂Pytorch实战实例VGG深度网络

    简单易懂Pytorch实战实例VGG深度网络 模型VGG,数据集cifar.对照这份代码走一遍,大概就知道整个pytorch的运行机制. 来源 定义模型: '''VGG11/13/16/19 in Pytorch.''' import torch import torch.nn as nn from torch.autograd import Variable cfg = {     'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M

  • 用pytorch的nn.Module构造简单全链接层实例

    python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架 1.先定义一个类Linear,继承nn.Module import torch as t from torch import nn from torch.autograd import Variable as V class Linear(nn.Module): '''因为Variable自动求导,所以不需要实现backward()''' def __init__(self, in_feat

  • 基于vue.js路由参数的实例讲解——简单易懂

    vue中,我们构建单页面应用时候,一定必不可少用到vue-router vue-router 就是我们的路由,这个由vue官方提供的插件 首先在我们项目中安装vue-router路由依赖 第一种,我们提供命令行来安装 npm install vue-router --save 第二种,我们直接去官方github下载 https://github.com/vuejs/vue-router 路由参数设置 1,实例化一个路由,然后路由映射表中的地址带参数,这个参数就是路由的参数 接着给映射表中的路由设

  • Python 网络爬虫--关于简单的模拟登录实例讲解

    和获取网页上的信息不同,想要进行模拟登录还需要向服务器发送一些信息,如账号.密码等等. 模拟登录一个网站大致分为这么几步: 1.先将登录网站的隐藏信息找到,并将其内容先进行保存(由于我这里登录的网站并没有额外信息,所以这里没有进行信息筛选保存) 2.将信息进行提交 3.获取登录后的信息 先给上源码 <span style="font-size: 14px;"># -*- coding: utf-8 -*- import requests def login(): sessi

  • vue递归组件实战之简单树形控件实例代码

    1.递归组件-简单树形控件预览及问题 在编写树形组件时遇到的问题: 组件如何才能递归调用? 递归组件点击事件如何传递? 2.树形控件基本结构及样式 <template> <ul class="vue-tree"> <li class="tree-item"> <div class="tree-content"><!--节点内容--> <div class="expand-

  • Python中的多线程实例(简单易懂)

    目录 1.python中显示当前线程信息的属性和方法 2.添加一个线程 3.线程中的join函数 4.使用Queue存储线程的结果 5.线程锁lock 前言: 多线程简单理解就是:一个CPU,也就是单核,将时间切成一片一片的,CPU轮转着去处理一件一件的事情,到了规定的时间片就处理下一件事情. 1.python中显示当前线程信息的属性和方法 # coding:utf-8 # 导入threading包 import threading if __name__ == "__main__":

  • Java设计模式之简单工厂 工厂方法 抽象工厂深度总结

    目录 工厂模式介绍 好处 常见的应用 简单工厂(Simple Factory) 适用场景 角色分配: 应用案例: 优缺点: 简单工厂实现: 工厂方法(Factory Method) 适用场景 角色分配: 应用案例: 优缺点: 工厂方法实现: 抽象工厂(Abstract Factory) 适用场景 角色分配 应用案例: 优缺点: 抽象工厂实现 抽象工厂终极改进(反射+配置文件+简单工厂) 工厂模式介绍 工厂模式也是非常常见的设计模式之一,其属于创建型模式.工厂模式分类:简单工厂(Simple Fa

  • Pytorch模型定义与深度学习自查手册

    目录 定义神经网络 权重初始化 方法1:net.apply(weights_init) 方法2:在网络初始化的时候进行参数初始化 常用的操作 利用nn.Parameter()设计新的层 nn.Flatten nn.Sequential 常用的层 全连接层nn.Linear() torch.nn.Dropout 卷积torch.nn.ConvNd() 池化 最大池化torch.nn.MaxPoolNd() 均值池化torch.nn.AvgPoolNd() 反池化 最大值反池化nn.MaxUnpoo

  • 简单易懂Java反射的setAccessible()方法

    目录 概要 一. 什么是Java的访问检查 二. setAccessible() 方法介绍 前言:在使用Field类的对象访问我自定义的Employee类对象的name域时,抛出异常illegalAccessException.查询原因为:在访问name域时,Java进行了访问检查,发现该域是private修饰的,不能直接访问,因此抛出异常. 概要 本文首先详细介绍访问检查的概念,然后介绍关于反射的运行时访问检查,并说明可能存在的问题.最后介绍可以通过setAccessible方法屏蔽或者说禁用

  • Python3使用PyQt5制作简单的画板/手写板实例

    1.前言 版本:Python3.6.1 + PyQt5 写一个程序的时候需要用到画板/手写板,只需要最简单的那种.原以为网上到处都是,结果找了好几天,都没有找到想要的结果. 网上的要么是非python版的qt程序(要知道qt版本之间差异巨大,还是非同一语言的),改写难度太大.要么是PyQt4的老程序,很多都已经不能在PyQt5上运行了.要么是大神写的特别复杂的程序,简直是直接做出了一个Windows自带的画图版,只能膜拜~ 于是我只能在众多代码中慢慢寻找自己需要的那一小部分,然后不断地拼凑,不断

随机推荐