pytorch 预训练层的使用方法
pytorch 预训练层的使用方法
将其他地方训练好的网络,用到新的网络里面
加载预训练网络
1.原先已经训练好一个网络 AutoEncoder_FC()
2.首先加载该网络,读取其存储的参数
3.设置一个参数集
cnnpre = AutoEncoder_FC() cnnpre.load_state_dict(torch.load('autoencoder_FC.pkl')['state_dict']) cnnpre_dict =cnnpre.state_dict()
加载新网络
1.设置新的网络
2.设置新网络参数集
cnn= AutoEncoder() cnn_dict = cnn.state_dict()
更新新网络参数
1.将两个参数集比对,存在的网络参数保留
2.使用保留下的参数更新新网络参数集
3.加载新网络参数集到新网络中
cnnpre_dict = {k: v for k, v in cnnpre_dict.items() if k in cnn_dict} cnn_dict.update(cnnpre_dict) cnn.load_state_dict(cnn_dict)
以上这篇pytorch 预训练层的使用方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
pytorch 固定部分参数训练的方法
需要自己过滤 optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3) 另外,如果是Variable,则可以初始化时指定 j = Variable(torch.randn(5,5), requires_grad=True) 但是如果是 m = nn.Linear(10,10) 是没有requires_grad传入的 m.requires_grad也没有 需要 for i in m.parameter
-
详解PyTorch批训练及优化器比较
一.PyTorch批训练 1. 概述 PyTorch提供了一种将数据包装起来进行批训练的工具--DataLoader.使用的时候,只需要将我们的数据首先转换为torch的tensor形式,再转换成torch可以识别的Dataset格式,然后将Dataset放入DataLoader中就可以啦. import torch import torch.utils.data as Data torch.manual_seed(1) # 设定随机数种子 BATCH_SIZE = 5 x = torch.li
-
Pytorch加载部分预训练模型的参数实例
前言 自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了.对于深度学习的初学者,Pytorch值得推荐.今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程. 直接加载预选脸模型 如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直
-
pytorch 更改预训练模型网络结构的方法
一个继承nn.module的model它包含一个叫做children()的函数,这个函数可以用来提取出model每一层的网络结构,在此基础上进行修改即可,修改方法如下(去除后两层): resnet_layer = nn.Sequential(*list(model.children())[:-2]) 那么,接下来就可以构建我们的网络了: class Net(nn.Module): def __init__(self , model): super(Net, self).__init__() #取
-
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的. 假如我们现在有一个简单的两层感知机网络: # -*- coding: utf-8 -*- import torch from torch.autograd import Variable import torch.optim as optim x = Variable(torch.FloatTensor([1, 2, 3])).cuda() y = Variable(torch.FloatTensor([4,
-
python PyTorch预训练示例
前言 最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢.各种设计直接简洁,方便研究,比tensorflow的臃肿好多了.今天让我们来谈谈PyTorch的预训练,主要是自己写代码的经验以及论坛PyTorch Forums上的一些回答的总结整理. 直接加载预训练模型 如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型: my_resnet = MyResNet(*args, **kwargs) my_resnet.load_state_dict(
-
pytorch 预训练层的使用方法
pytorch 预训练层的使用方法 将其他地方训练好的网络,用到新的网络里面 加载预训练网络 1.原先已经训练好一个网络 AutoEncoder_FC() 2.首先加载该网络,读取其存储的参数 3.设置一个参数集 cnnpre = AutoEncoder_FC() cnnpre.load_state_dict(torch.load('autoencoder_FC.pkl')['state_dict']) cnnpre_dict =cnnpre.state_dict() 加载新网络 1.设置新的网
-
基于pytorch 预训练的词向量用法详解
如何在pytorch中使用word2vec训练好的词向量 torch.nn.Embedding() 这个方法是在pytorch中将词向量和词对应起来的一个方法. 一般情况下,如果我们直接使用下面的这种: self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeding_dim) num_embeddings=vocab_size 表示词汇量的大小 embedding_dim=embeding
-
PyTorch预训练Bert模型的示例
本文介绍以下内容: 1. 使用transformers框架做预训练的bert-base模型: 2. 开发平台使用Google的Colab平台,白嫖GPU加速: 3. 使用datasets模块下载IMDB影评数据作为训练数据. transformers模块简介 transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型.使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用. 本来打算预训练bert-large模型,发现
-
PyTorch预训练的实现
前言 最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢.各种设计直接简洁,方便研究,比tensorflow的臃肿好多了.今天让我们来谈谈PyTorch的预训练,主要是自己写代码的经验以及论坛PyTorch Forums上的一些回答的总结整理. 直接加载预训练模型 如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型: my_resnet = MyResNet(*args, **kwargs) my_resnet.load_state_dict(
-
Transformer导论之Bert预训练语言解析
目录 Bert Pre-training BERT Fine-tuning BERT 代码实现 Bert BERT,全称为“Bidirectional Encoder Representations from Transformers”,是一种预训练语言表示的方法,意味着我们在一个大型文本语料库(如维基百科)上训练一个通用的“语言理解”模型,然后将该模型用于我们关心的下游NLP任务(如问答).BERT的表现优于之前的传统NLP方法,因为它是第一个用于预训练NLP的无监督的.深度双向系统. Ber
-
pytorch fine-tune 预训练的模型操作
之一: torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易.本文主要介绍如何 fine-tune torchvision 中预训练好的模型. 安装 pip install torchvision 如何 fine-tune 以 resnet18 为例: from torchvision import models from torch import nn from torch import optim resnet_model = models.resne
-
pytorch 修改预训练model实例
我就废话不多说了,直接上代码吧! class Net(nn.Module): def __init__(self , model): super(Net, self).__init__() #取掉model的后两层 self.resnet_layer = nn.Sequential(*list(model.children())[:-2]) self.transion_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
随机推荐
- JavaScript中Textarea滚动条不能拖动的解决方法
- JavaScript实现渐变色效果(不使用图片)
- Java 数组分析及简单实例
- Python tkinter模块弹出窗口及传值回到主窗口操作详解
- JavaScript设置名字输入不合法的实现方法
- bootstrap table表格插件使用详解
- tangram框架响应式加载图片方法
- 发布一个高效的JavaScript分析、压缩工具 JavaScript Analyser
- Mysql 取字段值逗号第一个数据的查询语句
- 浅析如何利用JavaScript进行语音识别
- asp.net错误页面处理示例分享
- linux如何编译安装新内核支持NTFS文件系统(以redhat7.2x64为例)
- MYSQL必知必会读书笔记第四章之检索数据
- 谈谈JavaScript中function多重理解
- C++按位异或运算符的使用介绍
- js 倒计时(高效率服务器时间同步)
- PHP实现JS中escape与unescape的方法
- 一个PHP+MSSQL分页的例子
- Android编程设计模式之抽象工厂模式详解
- Python实现简易Web爬虫详解