pytorch中的inference使用实例
这里inference两个程序的连接,如目标检测,可以利用一个程序提取候选框,然后把候选框输入到分类cnn网络中。
这里常需要进行一定的连接。
#加载训练好的分类CNN网络 model=torch.load('model.pkl') #假设proposal_img是我们提取的候选框,是需要输入到CNN网络的数据 #先定义transforms对输入cnn的网络数据进行处理,常包括resize、totensor等操作 data_transforms=transforms.Compose([transforms.RandomSizedCrop(224), transforms.ToTensor()]) #由于transforms是对PIL格式数据操作,所以必要时转化格式 def tensor_to_PIL(tensor): image = tensor.cpu().clone() image = image.squeeze(0) image = unloader(image) return image #unqueeze(0)是加多一维,对应原来batchsiaze data=data_transforms(proposal_img).unqueeze(0) #新版本pytorch已经不用variable,可以省略这句 data=Variable(data) #貌似这句也是多余的 torch.no_grad() predict=F.softmax(model(data.cuda()).cuda())
以上这篇pytorch中的inference使用实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
pytorch数据预处理错误的解决
出错: Traceback (most recent call last): File "train.py", line 305, in <module> train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler) File "train.py", line 145, in train_model for inputs, age_labels, gender_labels in
-
浅谈PyTorch的可重复性问题(如何使实验结果可复现)
由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致.因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子. 许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致.我经过许多调研和实验,总结了以下方法,记录下来. 全部设置可以分为三部分: 1. CUDNN cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率.如果需要保证可重复性,可以使用如下设置: from torch.backends import cu
-
pytorch实现保证每次运行使用的随机数都相同
其实在代码的开头添加下面几句话即可: # 保证训练时获取的随机数都是一样的 init_seed = 1 torch.manual_seed(init_seed) torch.cuda.manual_seed(init_seed) np.random.seed(init_seed) # 用于numpy的随机数 torch.manual_seed(seed) 为了生成随机数设置种子.返回一个torch.Generator对象 参数: seed (int) – 期望的种子数 torch.cuda.ma
-
pytorch模型存储的2种实现方法
1.保存整个网络结构信息和模型参数信息: torch.save(model_object, './model.pth') 直接加载即可使用: model = torch.load('./model.pth') 2.只保存网络的模型参数-推荐使用 torch.save(model_object.state_dict(), './params.pth') 加载则要先从本地网络模块导入网络,然后再加载参数: from models import AgeModel model = AgeModel()
-
pytorch中的inference使用实例
这里inference两个程序的连接,如目标检测,可以利用一个程序提取候选框,然后把候选框输入到分类cnn网络中. 这里常需要进行一定的连接. #加载训练好的分类CNN网络 model=torch.load('model.pkl') #假设proposal_img是我们提取的候选框,是需要输入到CNN网络的数据 #先定义transforms对输入cnn的网络数据进行处理,常包括resize.totensor等操作 data_transforms=transforms.Compose([trans
-
pytorch中的transforms模块实例详解
pytorch中的transforms模块中包含了很多种对图像数据进行变换的函数,这些都是在我们进行图像数据读入步骤中必不可少的,下面我们讲解几种最常用的函数,详细的内容还请参考pytorch官方文档(放在文末). data_transforms = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms
-
pytorch中图像的数据格式实例
计算机视觉方面朋友都需要跟图像打交道,在pytorch中图像与我们平时在matlab中见到的图像数据格式有所不同.matlab中我们通常使用函数imread()来轻松地读入一张图像,我们在变量空间中可看到数据的存储方式是H x W x C的顺序(其中H.W.C分别表示图像的高.宽和通道数,通道数一般为RGB三通道),另外,其中的每一个数据都是[0,255]的整数. 在使用pytorch的时候,我们通常要使用pytorch中torchvision包下面的datasets模块和transforms模
-
PyTorch中torch.nn.Linear实例详解
目录 前言 1. nn.Linear的原理: 2. nn.Linear的使用: 3. nn.Linear的源码定义: 补充:许多细节需要声明 总结 前言 在学习transformer时,遇到过非常频繁的nn.Linear()函数,这里对nn.Linear进行一个详解.参考:https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html 1. nn.Linear的原理: 从名称就可以看出来,nn.Linear表示的是线性变
-
pytorch中permute()函数用法实例详解
目录 前言 三维情况 变化一:不改变任何参数 变化二:1与2交换 变化三:0与1交换 变化四:0与2交换 变化五:0与1交换,1与2交换 变化六:0与1交换,0与2交换 总结 前言 本文只讨论二维三维中的permute用法 最近的Attention学习中的一个permute函数让我不理解 这个光说太抽象 我就结合代码与图片解释一下 首先创建一个三维数组小实例 import torch x = torch.linspace(1, 30, steps=30).view(3,2,5) # 设置一个三维
-
如何从PyTorch中获取过程特征图实例详解
目录 一.获取Tensor ①类型转换 ②张量拆解 ③图像展示 总结 一.获取Tensor 神经网络在运算过程中实际上是以Tensor为格式进行计算的,我们只需稍稍改动一下forward函数即可从运算过程中抓到Tensor 代码如下: base_feature = self.extractor.forward(x) #正常的前向传递 feature=base_feature.detach() #抓取tensor feature_imshow(feature) #展示函数(关键代码) 通过将过程张
-
pytorch中获取模型input/output shape实例
Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见 https://github.com/pytorch/pytorch/pull/3043 以下代码算一种workaround.由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码. 例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了. 该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape #coding
-
pytorch中的自定义反向传播,求导实例
pytorch中自定义backward()函数.在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包. 那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢.下面的代码展示了这个功能` import torch import numpy as np from PIL import Image from torch.autograd import gradcheck class Bicu
-
在pytorch中对非叶节点的变量计算梯度实例
在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行. 注册hook函数 Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在
-
在pytorch 中计算精度、回归率、F1 score等指标的实例
pytorch中训练完网络后,需要对学习的结果进行测试.官网上例程用的方法统统都是正确率,使用的是torch.eq()这个函数. 但是为了更精细的评价结果,我们还需要计算其他各个指标.在把官网API翻了一遍之后发现并没有用于计算TP,TN,FP,FN的函数... 在动了无数歪脑筋之后,心想pytorch完全支持numpy,那能不能直接进行判断,试了一下果然可以,上代码: # TP predict 和 label 同时为1 TP += ((pred_choice == 1) & (target.d
随机推荐
- Backbone.js中的集合详解
- python开发中range()函数用法实例分析
- HTML 编辑器 FCKeditor使用详解
- 图文详解SQL Server 2008R2使用教程
- jQuery plugin animsition使用小结
- 查看驱动器(盘符)的批处理
- 微信支付开发动态链接Native支付
- asp datediff 时间相减
- Eclipse中自动重构实现探索
- CentOS6.5下安装Mysql5.7.18的教程详解
- Mongoose中document与object的区别示例详解
- PHP数组生成XML格式数据的封装类实例
- SQL Server 数据库安全管理介绍
- mysql出现“Incorrect key file for table”处理方法
- 正负小数点后两位浮点数实现原理及代码
- jQuery源码分析-05异步队列 Deferred 使用介绍
- 删除文件提示文件正在被另一个人或程序使用的解决方法
- Hibernate单表操作实例解析
- Asp.net core中实现自动更新的Option的方法示例
- PHP框架Laravel中使用UUID实现数据分表操作示例