PyTorch学习笔记之回归实战

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

(0)

相关推荐

  • 详解Pytorch 使用Pytorch拟合多项式(多项式回归)

    使用Pytorch来编写神经网络具有很多优势,比起Tensorflow,我认为Pytorch更加简单,结构更加清晰. 希望通过实战几个Pytorch的例子,让大家熟悉Pytorch的使用方法,包括数据集创建,各种网络层结构的定义,以及前向传播与权重更新方式. 比如这里给出 很显然,这里我们只需要假定 这里我们只需要设置一个合适尺寸的全连接网络,根据不断迭代,求出最接近的参数即可. 但是这里需要思考一个问题,使用全连接网络结构是毫无疑问的,但是我们的输入与输出格式是什么样的呢? 只将一个x作为输入

  • PyTorch线性回归和逻辑回归实战示例

    线性回归实战 使用PyTorch定义线性回归模型一般分以下几步: 1.设计网络架构 2.构建损失函数(loss)和优化器(optimizer) 3.训练(包括前馈(forward).反向传播(backward).更新模型参数(update)) #author:yuquanle #data:2018.2.5 #Study of LinearRegression use PyTorch import torch from torch.autograd import Variable # train

  • PyTorch上搭建简单神经网络实现回归和分类的示例

    本文介绍了PyTorch上搭建简单神经网络实现回归和分类的示例,分享给大家,具体如下: 一.PyTorch入门 1. 安装方法 登录PyTorch官网,http://pytorch.org,可以看到以下界面: 按上图的选项选择后即可得到Linux下conda指令: conda install pytorch torchvision -c soumith 目前PyTorch仅支持MacOS和Linux,暂不支持Windows.安装 PyTorch 会安装两个模块,一个是torch,一个 torch

  • PyTorch学习笔记之回归实战

    本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn.functional as F # 主要实现激活函数 import matplotlib.pyplot as plt # 绘图的工具 from torch.autograd import Variable # 生成伪数据 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)

  • Pytorch学习笔记DCGAN极简入门教程

    目录 1.图片分类网络 2.图片生成网络 首先是图片分类网络: 重点是生成网络 每一个step分为三个步骤: 1.图片分类网络 这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片 2.图片生成网络 输入是一个随机噪声,输出是一张图片,使用的是反卷积层 相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了 首先是图片分类网络: 简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络

  • Java中jqGrid 学习笔记整理——进阶篇(二)

    相关阅读: Java中jqGrid 学习笔记整理--进阶篇(一) 本篇开始正式与后台(java语言)进行数据交互,使用的平台为 JDK:java 1.8.0_71 myEclisp 2015 Stable 2.0 Apache Tomcat-8.0.30 Mysql 5.7 Navicat for mysql 11.2.5(mysql数据库管理工具) 一.数据库部分 1.创建数据库 使用Navicat for mysql创建数据库(使用其他工具或直接使用命令行暂不介绍) 2. 2.创建表 双击打

  • Spring学习笔记3之消息队列(rabbitmq)发送邮件功能

    rabbitmq简介: MQ全称为Message Queue, 消息队列(MQ)是一种应用程序对应用程序的通信方法.应用程序通过读写出入队列的消息(针对应用程序的数据)来通信,而无需专用连接来链接它们.消息传递指的是程序之间通过在消息中发送数据进行通信,而不是通过直接调用彼此来通信,直接调用通常是用于诸如远程过程调用的技术.排队指的是应用程序通过 队列来通信.队列的使用除去了接收和发送应用程序同时执行的要求.其中较为成熟的MQ产品有IBM WEBSPHERE MQ. 本节的内容是用户注册时,将邮

  • Bootstrap学习笔记之css样式设计(2)

    首先,很感谢各位朋友对我的支持,关于bootstrap的学习总结,我会持续更新,如果有写的不对的地方,麻烦各位给我指正出来哈.关于上篇文章,固定布局和流式布局很关键,如果还不太清楚的可以再看看我写的:Bootstrap学习笔记之css样式设计(1) 这次我们来看看bootstrap中关于样式的一些具体关键的类以及如何使用这些类,类与类之间的区别,另外涉及到的一些相关类,举列子的时候解释. 一.表单 1.form-control类:含有此类的<input><select><te

  • Spring学习笔记1之IOC详解尽量使用注解以及java代码

    在实战中学习Spring,本系列的最终目的是完成一个实现用户注册登录功能的项目. 预想的基本流程如下: 1.用户网站注册,填写用户名.密码.email.手机号信息,后台存入数据库后返回ok.(学习IOC,mybatis,SpringMVC的基础知识,表单数据验证,文件上传等) 2.服务器异步发送邮件给注册用户.(学习消息队列) 3.用户登录.(学习缓存.Spring Security) 4.其他. 边学习边总结,不定时更新.项目环境为Intellij + Spring4. 一.准备工作. 1.m

  • Spring学习笔记2之表单数据验证、文件上传实例代码

    在上篇文章给大家介绍了Spring学习笔记1之IOC详解尽量使用注解以及java代码,接下来本文重点给大家介绍Spring学习笔记2之表单数据验证.文件上传实例代码,具体内容,请参考本文吧! 一.表单数据验证 用户注册时,需要填写账号.密码.邮箱以及手机号,均为必填项,并且需要符合一定的格式.比如账号需要32位以内,邮箱必须符合邮箱格式,手机号必须为11位号码等.可以采用在注册时验证信息,或者专门写一个工具类用来验证:来看下在SpringMVC中如何通过简单的注释实现表单数据验证. 在javax

  • angularjs学习笔记之简单介绍

    一.angularjs简介 AngularJS 是一个为动态WEB应用设计的结构框架.它能让你使用HTML作为模板语言,通过扩展HTML的语法,让你能更清楚.简洁地构建你的应用组件.它的创新点在于,利用 数据绑定 和 依赖注入,它使你不用再写大量的代码了.这些全都是通过浏览器端的Javascript实现,这也使得它能够完美地和任何服务器端技术结合. 说了这么多,估计你啥都没有理解...对吗?别着急,我来说说他的几个特点吧:模块化,数据双向绑定,依赖注入,指令.下面我们就跟着这几个特点进行学习.

  • AngularJS学习笔记之表单验证功能实例详解

    本文实例讲述了AngularJS学习笔记之表单验证功能.分享给大家供大家参考,具体如下: 一.执行基本的表单验证 <!DOCTYPE html> <html ng-app='exampleApp'> <head> <meta charset="UTF-8"> <title>表单</title> <script src="../../js/angular.min.js" type="

随机推荐