Pytorch之finetune使用详解
finetune分为全局finetune和局部finetune。首先介绍一下局部finetune步骤:
1.固定参数
for name, child in model.named_children(): for param in child.parameters(): param.requires_grad = False
后,只传入 需要反传的参数,否则会报错
filter(lambda param: param.requires_grad, model.parameters())
2.调低学习率,加快衰减
finetune是在预训练模型上进行微调,学习速率不能太大。
目前不清楚:学习速率降低的幅度可以更快一些。这样以来,在使用step的策略时,stepsize可以更小一些。
直接从原始数据训练的base_lr一般为0.01,微调要比0.01小,置为0.001
要比直接训练的小一些,直接训练的stepsize为100000,finetune的stepsize: 50000
3. 固定bn或取消dropout:
batchnorm会影响训练的效果,随着每个batch,追踪样本的均值和方差。对于固定的网络,bn应该使用全局的数值
def freeze_bn(self): for layer in self.modules(): if isinstance(layer, nn.BatchNorm2d): layer.eval()
训练时,model.train()会修改模式,freeze_zn()应该在这里后面
4.过滤参数
训练时,对于优化器,应该只传入需要改变的参数,否则会报错
filter(lambda p: p.requires_grad, model.parameters())
以上这篇Pytorch之finetune使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
python PyTorch参数初始化和Finetune
前言 这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是"最佳实践"吧.最后希望大家没事多逛逛论坛,有很多高质量的回答. 参数初始化 参数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了.这就是PyTorch简洁高效所在. 所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法
-
Pytorch之finetune使用详解
finetune分为全局finetune和局部finetune.首先介绍一下局部finetune步骤: 1.固定参数 for name, child in model.named_children(): for param in child.parameters(): param.requires_grad = False 后,只传入 需要反传的参数,否则会报错 filter(lambda param: param.requires_grad, model.parameters()) 2.调低学
-
pytorch AvgPool2d函数使用详解
我就废话不多说了,直接上代码吧! import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable import numpy as np input = Variable(torch.Tensor([[[1, 3, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]], [[1, 3, 3, 4, 5, 6, 7], [1, 2, 3
-
pytorch之ImageFolder使用详解
pytorch之ImageFolder torchvision已经预先实现了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet.COCO.MNIST.LSUN等数据集,可通过诸如torchvision.datasets.CIFAR10来调用.在这里介绍一个会经常使用到的Dataset--ImageFolder. ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下: ImageFolder(root, tra
-
Anaconda+vscode+pytorch环境搭建过程详解
1.安装Anaconda Anaconda指的是一个开源的Python发行版本,其包含了conda.Python等180多个科学包及其依赖项.在官网上下载https://www.anaconda.com/distribution/,因为服务器在国外会很慢,建议从清华镜像https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/下载. 2.安装VScode 需要在Anaconda再装VScode,因为Anaconda公司和微软公司的合作,不用在对进
-
pytorch 限制GPU使用效率详解(计算效率)
问题 用过 tensorflow 的人都知道, tf 可以限制程序在 GPU 中的使用效率,但 pytorch 中没有这个操作. 思路 于是我想到了一个代替方法,玩过单片机点灯的同学都知道,灯的亮度是靠占空比实现的,这实际上也是计算机的运行原理. 那我们是不是也可以通过增加 GPU 不工作的时间,进而降低 GPU 的使用效率 ? 主要代码 import time ... rest_time = 0.15 ... for _ in range( XXX ): ... outputs = all_G
-
Pytorch自动求导函数详解流程以及与TensorFlow搭建网络的对比
一.定义新的自动求导函数 在底层,每个原始的自动求导运算实际上是两个在Tensor上运行的函数.其中,forward函数计算从输入Tensor获得的输出Tensors.而backward函数接收输出,Tensors对于某个标量值得梯度,并且计算输入Tensors相对于该相同标量值得梯度. 在Pytorch中,可以容易地通过定义torch.autograd.Function的子类实现forward和backward函数,来定义自动求导函数.之后就可以使用这个新的自动梯度运算符了.我们可以通过构造一
-
Python人工智能学习PyTorch实现WGAN示例详解
目录 1.GAN简述 2.生成器模块 3.判别器模块 4.数据生成模块 5.判别器训练 6.生成器训练 7.结果可视化 1.GAN简述 在GAN中,有两个模型,一个是生成模型,用于生成样本,一个是判别模型,用于判断样本是真还是假.但由于在GAN中,使用的JS散度去计算损失值,很容易导致梯度弥散的情况,从而无法进行梯度下降更新参数,于是在WGAN中,引入了Wasserstein Distance,使得训练变得稳定.本文中我们以服从高斯分布的数据作为样本. 2.生成器模块 这里从2维数据,最终生成2
-
人工智能学习Pytorch张量数据类型示例详解
目录 1.python 和 pytorch的数据类型区别 2.张量 ①一维张量 ②二维张量 ③3维张量 ④4维张量 1.python 和 pytorch的数据类型区别 在PyTorch中无法展示字符串,因此表达字符串,需要将其转换成编码的类型,比如one_hot,word2vec等. 2.张量 在python中,会有标量,向量,矩阵等的区分.但在PyTorch中,这些统称为张量tensor,只是维度不同而已. 标量就是0维张量,只有一个数字,没有维度. 向量就是1维张量,是有顺序的数字,但没有"
-
pytorch中使用LSTM详解
目录 LSMT层 1.__init__方法 2.forward方法的输入 3.forward方法的输出 LSTMCell LSMT层 可以在troch.nn模块中找到LSTM类 lstm = torch.nn.LSTM(*paramsters) 1.__init__方法 首先对nn.LSTM类进行实例化,需要传入的参数如下图所示: 一般我们关注这4个: input_size表示输入的每个token的维度,也可以理解为一个word的embedding的维度. hidden_size表示隐藏层也就是
-
如何使用Pytorch完成图像分类任务详解
目录 概述: 一. 数据准备 二.定义一个卷积神经网络 三.完整代码如下: 总结 概述: 本文将通过组织自己的训练数据,使用Pytorch深度学习框架来训练自己的模型,最终实现自己的图像分类!本篇文章以识别阳台为例子,进行讲述. 一. 数据准备 深度学习的基础就是数据,完成图像分类,当然数据也必不可少.先使用爬虫爬取阳台图片1200张以及非阳台图片1200张,图片的名字从0.jpg一直编到2400.jpg,把爬取的图片放置在同一个文件夹中命名为image(如下图1所示). 图1 针对百度图片的爬
随机推荐
- ajax异步刷新实现更新数据库
- Angular和百度地图的结合实例代码
- Swift心得笔记之集合类型
- js获取ajax返回值代码
- MVC4 网站发布(整理+部分问题收集和解决方案)
- asp.net coolite 删除时弹出确定按钮
- PHP判断远程图片是否存在的几种方法
- PHP中实现接收多个name相同但Value不相同表单数据实例
- php生成二维码
- python通过imaplib模块读取gmail里邮件的方法
- python制作小说爬虫实录
- Python中time模块与datetime模块在使用中的不同之处
- php格式化电话号码的方法
- JS实现保留n位小数的四舍五入问题示例
- Nodejs中Express 常用中间件 body-parser 实现解析
- 常用类之TCP连接类-socket编程
- C#保存与读取DataTable信息到XML格式的方法
- jQuery参数列表集合
- Javascript设计模式之装饰者模式详解篇
- JavaScript字符串删除重复字符的方法