PyTorch快速搭建神经网络及其保存提取方法详解

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch
import torch.nn.functional as F 

# 方法1,通过定义一个Net类来建立神经网络
class Net(torch.nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = torch.nn.Linear(n_feature, n_hidden)
    self.predict = torch.nn.Linear(n_hidden, n_output) 

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.predict(x)
    return x 

net1 = Net(2, 10, 2)
print('方法1:\n', net1) 

# 方法2 通过torch.nn.Sequential快速建立神经网络结构
net2 = torch.nn.Sequential(
  torch.nn.Linear(2, 10),
  torch.nn.ReLU(),
  torch.nn.Linear(10, 2),
  )
print('方法2:\n', net2)
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 

'''''
方法1:
 Net (
 (hidden): Linear (2 -> 10)
 (predict): Linear (10 -> 2)
)
方法2:
 Sequential (
 (0): Linear (2 -> 10)
 (1): ReLU ()
 (2): Linear (10 -> 2)
)
''' 

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt 

torch.manual_seed(1) # 设定随机数种子 

# 创建数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 

# 将待保存的神经网络定义在一个函数中
def save():
  # 神经网络结构
  net1 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
    )
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
  loss_function = torch.nn.MSELoss() 

  # 训练部分
  for i in range(300):
    prediction = net1(x)
    loss = loss_function(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() 

  # 绘图部分
  plt.figure(1, figsize=(10, 3))
  plt.subplot(131)
  plt.title('net1')
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 

  # 保存神经网络
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 

# 载入整个神经网络的结构及其模型参数
def reload_net():
  net2 = torch.load('7-net.pkl')
  prediction = net2(x) 

  plt.subplot(132)
  plt.title('net2')
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 

# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构
def reload_params():
  # 首先搭建相同的神经网络结构
  net3 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
    ) 

  # 载入神经网络的模型参数
  net3.load_state_dict(torch.load('7-net_params.pkl'))
  prediction = net3(x) 

  plt.subplot(133)
  plt.title('net3')
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 

# 运行测试
save()
reload_net()
reload_params() 

实验结果:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

您可能感兴趣的文章:

  • PyTorch上搭建简单神经网络实现回归和分类的示例
  • PyTorch上实现卷积神经网络CNN的方法
(0)

相关推荐

  • PyTorch上实现卷积神经网络CNN的方法

    一.卷积神经网络 卷积神经网络(ConvolutionalNeuralNetwork,CNN)最初是为解决图像识别等问题设计的,CNN现在的应用已经不限于图像和视频,也可用于时间序列信号,比如音频信号和文本数据等.CNN作为一个深度学习架构被提出的最初诉求是降低对图像数据预处理的要求,避免复杂的特征工程.在卷积神经网络中,第一个卷积层会直接接受图像像素级的输入,每一层卷积(滤波器)都会提取数据中最有效的特征,这种方法可以提取到图像中最基础的特征,而后再进行组合和抽象形成更高阶的特征,因此CNN在

  • PyTorch上搭建简单神经网络实现回归和分类的示例

    本文介绍了PyTorch上搭建简单神经网络实现回归和分类的示例,分享给大家,具体如下: 一.PyTorch入门 1. 安装方法 登录PyTorch官网,http://pytorch.org,可以看到以下界面: 按上图的选项选择后即可得到Linux下conda指令: conda install pytorch torchvision -c soumith 目前PyTorch仅支持MacOS和Linux,暂不支持Windows.安装 PyTorch 会安装两个模块,一个是torch,一个 torch

  • PyTorch快速搭建神经网络及其保存提取方法详解

    有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解 一.PyTorch快速搭建神经网络方法 先看实验代码: import torch import torch.nn.functional as F # 方法1,通过定义一个Net类来建立神经网络 class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output):

  • pytorch快速搭建神经网络_Sequential操作

    之前用Class类来搭建神经网络 class Neuro_net(torch.nn.Module): """神经网络""" def __init__(self, n_feature, n_hidden_layer, n_output): super(Neuro_net, self).__init__() self.hidden_layer = torch.nn.Linear(n_feature, n_hidden_layer) self.outp

  • Mac下快速搭建PHP开发环境步骤详解

    最近做了一个后端的项目,是用PHP+MySQL+Nginx做的,所以把搭建环境的方法简单总结一下. 备注: 物料:Apache/Nginx+PHP+MySQL+MAMPMac OS 10.12.1 自带Apache,Nginx和PHP 1.运行Apache 查看Apache版本,在终端根目录输入如下命令: sudo apachectl -v 终端会输出Apache的版本及built时间 Server version: Apache/2.4.23 (Unix) Server built:   Au

  • 快速搭建React的环境步骤详解

    前端生态这几年可谓迎来了大发展,在这个生态圈内,不接受新事物学习新技能,等于堕入魔道. 本文尝试对前端开发利器React,以及构建项目过程中涉及的技术栈进行介绍,以期开启整个构建流程上的思考. 有必要指出的是,要弄明白一件事情的原理,首先要知道它的目的是什么. 1.Nodejs & NPM 为什么要提nodejs呢? 与其说nodejs提供了服务端开发的另一种可能,不如说它彻底改变了整个前端开发的生态.nodejs平台上衍生出了强大的npm.grunt.express等,几乎重新定义了前端的工作

  • 对Pytorch神经网络初始化kaiming分布详解

    函数的增益值 torch.nn.init.calculate_gain(nonlinearity, param=None) 提供了对非线性函数增益值的计算. 增益值gain是一个比例值,来调控输入数量级和输出数量级之间的关系. fan_in和fan_out pytorch计算fan_in和fan_out的源码 def _calculate_fan_in_and_fan_out(tensor): dimensions = tensor.ndimension() if dimensions < 2:

  • Python深度学习pytorch神经网络图像卷积运算详解

    目录 互相关运算 卷积层 特征映射 由于卷积神经网络的设计是用于探索图像数据,本节我们将以图像为例. 互相关运算 严格来说,卷积层是个错误的叫法,因为它所表达的运算其实是互相关运算(cross-correlation),而不是卷积运算.在卷积层中,输入张量和核张量通过互相关运算产生输出张量. 首先,我们暂时忽略通道(第三维)这一情况,看看如何处理二维图像数据和隐藏表示.下图中,输入是高度为3.宽度为3的二维张量(即形状为 3 × 3 3\times3 3×3).卷积核的高度和宽度都是2. 注意,

  • Java 用Prometheus搭建实时监控系统过程详解

    上帝之火 本系列讲述的是开源实时监控告警解决方案Prometheus,这个单词很牛逼.每次我都能联想到带来上帝之火的希腊之神,普罗米修斯.而这个开源的logo也是火,个人挺喜欢这个logo的设计. 本系列着重介绍Prometheus以及如何用它和其周边的生态来搭建一套属于自己的实时监控告警平台. 本系列受众对象为初次接触Prometheus的用户,大神勿喷,偏重于操作和实战,但是重要的概念也会精炼出提及下.系列主要分为以下几块 Prometheus各个概念介绍和搭建,如何抓取数据(本次分享内容)

  • Python Flask 搭建微信小程序后台详解

    前言: 近期需要开发一个打分的微信小程序,涉及到与后台服务器的数据交互,因为业务逻辑相对简单,故选择Python的轻量化web框架Flask来搭建后台程序.因为是初次接触小程序,经过一番摸索和尝试,个人觉得的微信小程序与后台的交互有点像ajax,所以有ajax开发经验的同学开发小程序应该很容易上手,因为本文着重讲解后台程序的搭建,所以,微信小程序的前端开发将一笔带过,有兴趣学习小程序前端语言的同学可移步网易云课堂的一套快速入门课程<轻松玩转微信小程序>. 分三步讲解微信小程序与Python后台

  • pytorch对可变长度序列的处理方法详解

    主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法. 1.torch.nn.utils.rnn.PackedSequence() NOTE: 这个类的实例不能手动创建.它们只能被 pack_padded_sequence() 实例化. PackedSequence

  • pytorch的梯度计算以及backward方法详解

    基础知识 tensors: tensor在pytorch里面是一个n维数组.我们可以通过指定参数reuqires_grad=True来建立一个反向传播图,从而能够计算梯度.在pytorch中一般叫做dynamic computation graph(DCG)--即动态计算图. import torch import numpy as np # 方式一 x = torch.randn(2,2, requires_grad=True) # 方式二 x = torch.autograd.Variabl

随机推荐