PyTorch实现线性回归详细过程

目录
  • 一、实现步骤
    • 1、准备数据
    • 2、设计模型
    • 3、构造损失函数和优化器
    • 4、训练过程
    • 5、结果展示
  • 二、参考文献

一、实现步骤

1、准备数据

x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])

2、设计模型

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
        
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
        
model = LinearModel()  

3、构造损失函数和优化器

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

4、训练过程

epoch_list = []
loss_list = []
w_list = []
b_list = []
for epoch in range(1000):
    y_pred = model(x_data)                      # 计算预测值
    loss = criterion(y_pred, y_data)    # 计算损失
    print(epoch,loss)
    
    epoch_list.append(epoch)
    loss_list.append(loss.data.item())
    w_list.append(model.linear.weight.item())
    b_list.append(model.linear.bias.item())
    
    optimizer.zero_grad()   # 梯度归零
    loss.backward()         # 反向传播
    optimizer.step()        # 更新

5、结果展示

展示最终的权重和偏置:

# 输出权重和偏置
print('w = ',model.linear.weight.item())
print('b = ',model.linear.bias.item())

结果为:

w =  1.9998501539230347
b =  0.0003405189490877092

模型测试:

# 测试模型
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ',y_test.data)

y_pred =  tensor([[7.9997]])

分别绘制损失值随迭代次数变化的二维曲线图和其随权重与偏置变化的三维散点图:

# 二维曲线图
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

# 三维散点图
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(w_list,b_list,loss_list,c='r')
#设置坐标轴
ax.set_xlabel('weight')
ax.set_ylabel('bias')
ax.set_zlabel('loss')
plt.show()

结果如下图所示:

到此这篇关于PyTorch实现线性回归详细过程的文章就介绍到这了,更多相关PyTorch线性回归内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

二、参考文献

  • [1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=5
(0)

相关推荐

  • 使用pytorch实现线性回归

    本文实例为大家分享了pytorch实现线性回归的具体代码,供大家参考,具体内容如下 线性回归都是包括以下几个步骤:定义模型.选择损失函数.选择优化函数. 训练数据.测试 import torch import matplotlib.pyplot as plt # 构建数据集 x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]]) y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0]

  • pytorch实现线性回归以及多元回归

    本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下 最近在学习pytorch,现在把学习的代码放在这里,下面是github链接 直接附上github代码 # 实现一个线性回归 # 所有的层结构和损失函数都来自于 torch.nn # torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93

  • pytorch使用Variable实现线性回归

    本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下 一.手动计算梯度实现线性回归 #导入相关包 import torch as t import matplotlib.pyplot as plt #构造数据 def get_fake_data(batch_size = 8): #设置随机种子数,这样每次生成的随机数都是一样的 t.manual_seed(10) #产生随机数据:y = 2*x+3,加上了一些噪声 x = t.rand(batch

  • pytorch实现线性回归

    pytorch实现线性回归代码练习实例,供大家参考,具体内容如下 欢迎大家指正,希望可以通过小的练习提升对于pytorch的掌握 # 随机初始化一个二维数据集,使用朋友torch训练一个回归模型 import numpy as np import random import matplotlib.pyplot as plt x = np.arange(20) y = np.array([5*x[i] + random.randint(1,20) for i in range(len(x))])

  • 利用Pytorch实现简单的线性回归算法

    最近听了张江老师的深度学习课程,用Pytorch实现神经网络预测,之前做Titanic生存率预测的时候稍微了解过Tensorflow,听说Tensorflow能做的Pyorch都可以做,而且更方便快捷,自己尝试了一下代码的逻辑确实比较简单. Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量),对于这些概念我也是一知半解,tensor和向量,矩阵等概念都有交叉的部分,下次有时间好好补一下数学的基础知识,不过现阶段的任务主要是应用,学习掌握思维和方法即可,就不再

  • PyTorch实现线性回归详细过程

    目录 一.实现步骤 1.准备数据 2.设计模型 3.构造损失函数和优化器 4.训练过程 5.结果展示 二.参考文献 一.实现步骤 1.准备数据 x_data = torch.tensor([[1.0],[2.0],[3.0]]) y_data = torch.tensor([[2.0],[4.0],[6.0]]) 2.设计模型 class LinearModel(torch.nn.Module):     def __init__(self):         super(LinearModel

  • Windows10+anacond+GPU+pytorch安装详细过程

    1.查看自己电脑是否匹配GPU版本. 设备管理器查看. 查看官网是否匹配.地址:https://developer.nvidia.com/cuda-gpus  ** 2.进入NVIDIA对电脑版本进行查**看. 如果可以的的话可以自己卸载原来版本,后安装新版本.安装地址https://developer.nvidia.com/cuda-toolkit-archive 接下来,进入NVIDIA安装过程,在这安装过程中,我一开始直接选择的精简安装,但由于VS的原因,导致无法正常安装,于是我换成了自定

  • 详解mongoDB主从复制搭建详细过程

    详解mongoDB主从复制搭建详细过程 实验目的搭建mongoDB主从复制 主 192.168.0.4 从 192.168.0.7 mongodb的安装 1: 下载mongodb www.mongodb.org 下载最新的stable版 查看自己服务器 适合哪个种方式下载(wget 不可以的 可以用下面方式下载) wget https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-rhel62-3.0.5.tgz curl -O -L https

  • 浅谈vue-lazyload实现的详细过程

    本文介绍了浅谈vue-lazyload实现的详细过程,分享给大家,也给自己留个笔记 首先 ,在命令行输入npm install vue-lazyload&&cnpm install vue-lazyload 然后,在main.js里引入这个模块. import 'VueLazyload' from 'vue-lazyload' Vue.use(VueLazyload,{ preload:1.3,//预加载的宽高 loading:"img的加载中的显示的图片的路径", e

  • windows 2008r2+php5.6.28环境搭建详细过程

    安装IIS7 1.打开服务器管理器(开始-计算机-右键-管理-也可以打开),添加角色 直接下一步 勾选Web服务器(IIS),下一步,有个注意事项继续下一步(这里我就不截图了) 勾选ASP.NET会弹出以下窗口添加所需的角色服务,勾选CGI(这里根据个人情况勾选,CGI是必选的,否则PHP不生效的) 然后直接下一步安装即可,需要等待一小会! 此时已安装成功,关闭即可,打开IIS管理器,如下图 将原来的网站删除,添加新网站 网站名称随便起,物理路径即表示你的根路径,我在D盘建立个www文件夹作为根

  • centos 6.9 升级glibc动态库的详细过程

    glibc是gnu发布的libc库,即c运行库,glibc是linux系统中最底层的api,几乎其它任何运行库都会依赖于glibc.glibc除了封装linux操作系统所提供的系统服务外,它本身也提供了许多其它一些必要功能服务的实现.很多linux的基本命令,比如ls,mv,cp, rm, ll,ln等,都得依赖于它,如果操作错误或者升级失败会导致系统命令不能使用,严重的造成系统退出后无法重新进入,所以操作时候需要慎重,升级之前保存好重要资料. 写这篇笔记的目的其实是我在centos 下想要安装

  • linux 之centos7搭建mysql5.7.29的详细过程

    1.下载mysql 1.1下载地址 https://downloads.mysql.com/archives/community/ 1.2版本选择 2.管理组及目录权限 2.1解压mysql tar -zxf mysql-5.7.29-linux-glibc2.12-x86_64.tar.gz上传目录/home/tools 2.2重命名 mv mysql-5.7.29-linux-glibc2.12-x86_64 mysql-5.7.29 2.3移动指定目录 mv mysql-5.7.29 /u

  • Ubuntu19.10开启ssh服务(详细过程)

    Ubuntu开启个ssh竟然花了我一个多小时,主要是一开始看的教程步骤不详细,然后我开启的是一个一万多的主机,开关机都挺慢的,在这里记录下详细步骤,方便自己以后查看 第一步,查看ssh是否已经开启 sudo ps -e | grep ssh 如果最后返回是sshd,证明ssh已经开启,跳到第四步 第二步,如果没有显示,试着开启ssh服务 sudo /etc/init.d/ssh start 如果返回的是命令未找到,证明未安装ssh服务 第三步,安装openssh服务查看服务有没有开启 sudo

  • springboot+idea+maven 多模块项目搭建的详细过程(连接数据库进行测试)

    创建之前项目之前 记得改一下 maven  提高下载Pom速度 记得 setting 中要改 maven  改成 阿里云的.具体方法 网上查第一步 搭建parents 项目,为maven项目 ,不为springboot 项目 记得修改groupId 第二步 搭建多个子模块, honor-dao   honor-manager   honor-common记得创建 honor-manager 的时候 要把他的gruopId 改成com.honor.manager 这里爆红的原因是 因为 我做到后面

  • Nginx下配置Https证书详细过程

    一.Http与Https的区别 HTTP:是互联网上应用最为广泛的一种网络协议,是一个客户端和服务器端请求和应答的标准(TCP),用于从WWW服务器传输超文本到本地浏览器的传输协议,它可以使浏览器更加高效,使网络传输减少. HTTPS:是以安全为目标的HTTP通道,简单讲是HTTP的安全版,即HTTP下加入SSL层,HTTPS的安全基础是SSL,因此加密的详细内容就需要SSL.HTTPS协议的主要作用可以分为两种:一种是建立一个信息安全通道,来保证数据传输的安全:另一种就是确认网站的真实性. H

随机推荐