Pytorch根据layers的name冻结训练方式
使用model.named_parameters()可以轻松搞定,
model.cuda() # ######################################## Froze some layers to fine-turn the model ######################## for name, param in model.named_parameters(): # 带有参数名的模型的各个层包含的参数遍历 if 'out' or 'merge' or 'before_regress' in name: # 判断参数名字符串中是否包含某些关键字 continue param.requires_grad = False # ############################################################################################################# optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.learning_rate * args.world_size, momentum=0.9, weight_decay=5e-4)
以上这篇Pytorch根据layers的name冻结训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
pytorch使用指定GPU训练的实例
本文适合多GPU的机器,并且每个用户需要单独使用GPU训练. 虽然pytorch提供了指定gpu的几种方式,但是使用不当的话会遇到out of memory的问题,主要是因为pytorch会在第0块gpu上初始化,并且会占用一定空间的显存.这种情况下,经常会出现指定的gpu明明是空闲的,但是因为第0块gpu被占满而无法运行,一直报out of memory错误. 解决方案如下: 指定环境变量,屏蔽第0块gpu CUDA_VISIBLE_DEVICES = 1 main.py 这句话表示只有第1块
-
pytorch 指定gpu训练与多gpu并行训练示例
一. 指定一个gpu训练的两种方法: 1.代码中指定 import torch torch.cuda.set_device(id) 2.终端中指定 CUDA_VISIBLE_DEVICES=1 python 你的程序 其中id就是你的gpu编号 二. 多gpu并行训练: torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0) 该函数实现了在module级别上的数据并行使用,注意batch size要大于G
-
Pytorch根据layers的name冻结训练方式
使用model.named_parameters()可以轻松搞定, model.cuda() # ######################################## Froze some layers to fine-turn the model ######################## for name, param in model.named_parameters(): # 带有参数名的模型的各个层包含的参数遍历 if 'out' or 'merge' or 'bef
-
Pytorch 数据加载与数据预处理方式
数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子
-
pytorch finetuning 自己的图片进行训练操作
一.pytorch finetuning 自己的图片进行训练 这种读取图片的方式用的是torch自带的 ImageFolder,读取的文件夹必须在一个大的子文件下,按类别归好类. 就像我现在要区分三个类别. #perpare data set #train data train_data=torchvision.datasets.ImageFolder('F:/eyeDataSet/trainData',transform=transforms.Compose( [ transforms.Sca
-
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的. 假如我们现在有一个简单的两层感知机网络: # -*- coding: utf-8 -*- import torch from torch.autograd import Variable import torch.optim as optim x = Variable(torch.FloatTensor([1, 2, 3])).cuda() y = Variable(torch.FloatTensor([4,
-
在Pytorch中计算自己模型的FLOPs方式
https://github.com/Lyken17/pytorch-OpCounter 安装方法很简单: pip install thop 基本用法: from torchvision.models import resnet50from thop import profile model = resnet50() flops, params = profile(model, input_size=(1, 3, 224,224)) 对自己的module进行特别的计算: class YourMo
-
pytorch中tensor张量数据类型的转化方式
1.tensor张量与numpy相互转换 tensor ----->numpy import torch a=torch.ones([2,5]) tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]]) # ********************************** b=a.numpy() array([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]], dtype=float32) numpy --
-
pytorch实现focal loss的两种方式小结
我就废话不多说了,直接上代码吧! import torch import torch.nn.functional as F import numpy as np from torch.autograd import Variable ''' pytorch实现focal loss的两种方式(现在讨论的是基于分割任务) 在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别 ''' def compute_class_weights(histogram): classWeigh
-
pytorch载入预训练模型后,实现训练指定层
1.有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练: pretrained_params = torch.load('Pretrained_Model') model = The_New_Model(xxx) model.load_state_dict(pretrained_params.state_dict(), strict=False) strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃. 2.
-
PyTorch使用cpu加载模型运算方式
没gpu没cuda支持的时候加载模型到cpu上计算 将 model = torch.load(path, map_location=lambda storage, loc: storage.cuda(device)) 改为 model = torch.load(path, map_location='cpu') 然后删掉所有变量后面的.cuda()方法 以上这篇PyTorch使用cpu加载模型运算方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.
-
keras多显卡训练方式
使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用. 要使用多张显卡,需要按如下步骤: (1)import multi_gpu_model函数:from keras.utils import multi_gpu_model (2)在定义好model之后,使用multi_gpu
随机推荐
- Python 列表(List)操作方法详解
- 一个不错的js html页面倒计时可精确到秒
- Angular.Js的自动化测试详解
- xshell上传下载文件(Windows、Linux)
- 数据库中排序的对比及使用条件详解
- JavaScript在Android的WebView中parseInt函数转换不正确问题解决方法
- 解析如何使用Zend Framework 连接数据库
- Java正则表达式使用
- win10下Python3.6安装、配置以及pip安装包教程
- python字符串连接的N种方式总结
- Docker 镜像、容器、仓库的概念及应用详解
- Python可跨平台实现获取按键的方法
- Bootstrap 组件之按钮(二)
- asp+mysql+utf8 网页出现乱码问题的解决方法
- java实现单链表之逆序
- 判断及设置浏览器全屏模式
- JavaScript的instanceof运算符学习教程
- MyEclipse去除网上复制下来的代码带有的行号(正则去除行号)
- CentOS 6.3安装配置Weblogic-10方法
- 兼容firefox,chrome的网页灰度效果