在pytorch 中计算精度、回归率、F1 score等指标的实例
pytorch中训练完网络后,需要对学习的结果进行测试。官网上例程用的方法统统都是正确率,使用的是torch.eq()这个函数。
但是为了更精细的评价结果,我们还需要计算其他各个指标。在把官网API翻了一遍之后发现并没有用于计算TP,TN,FP,FN的函数。。。
在动了无数歪脑筋之后,心想pytorch完全支持numpy,那能不能直接进行判断,试了一下果然可以,上代码:
# TP predict 和 label 同时为1 TP += ((pred_choice == 1) & (target.data == 1)).cpu().sum() # TN predict 和 label 同时为0 TN += ((pred_choice == 0) & (target.data == 0)).cpu().sum() # FN predict 0 label 1 FN += ((pred_choice == 0) & (target.data == 1)).cpu().sum() # FP predict 1 label 0 FP += ((pred_choice == 1) & (target.data == 0)).cpu().sum() p = TP / (TP + FP) r = TP / (TP + FN) F1 = 2 * r * p / (r + p) acc = (TP + TN) / (TP + TN + FP + FN
这样就能看到各个指标了。
因为target是Variable所以需要用target.data取到对应的tensor,又因为是在gpu上算的,需要用 .cpu() 移到cpu上。
因为这是一个batch的统计,所以需要用+=累计出整个epoch的统计。当然,在epoch开始之前需要清零
以上这篇在pytorch 中计算精度、回归率、F1 score等指标的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
Pytorch 计算误判率,计算准确率,计算召回率的例子
无论是官方文档还是各位大神的论文或搭建的网络很多都是计算准确率,很少有计算误判率, 下面就说说怎么计算准确率以及误判率.召回率等指标 1.计算正确率 获取每批次的预判正确个数 train_correct = (pred == batch_y.squeeze(1)).sum() 该语句的意思是 预测的标签与实际标签相等的总数 获取训练集总的预判正确个数 train_acc += train_correct.data[0] #用来计算正确率 准确率 : train_acc / (len(train_
-
在pytorch 中计算精度、回归率、F1 score等指标的实例
pytorch中训练完网络后,需要对学习的结果进行测试.官网上例程用的方法统统都是正确率,使用的是torch.eq()这个函数. 但是为了更精细的评价结果,我们还需要计算其他各个指标.在把官网API翻了一遍之后发现并没有用于计算TP,TN,FP,FN的函数... 在动了无数歪脑筋之后,心想pytorch完全支持numpy,那能不能直接进行判断,试了一下果然可以,上代码: # TP predict 和 label 同时为1 TP += ((pred_choice == 1) & (target.d
-
在Pytorch中计算自己模型的FLOPs方式
https://github.com/Lyken17/pytorch-OpCounter 安装方法很简单: pip install thop 基本用法: from torchvision.models import resnet50from thop import profile model = resnet50() flops, params = profile(model, input_size=(1, 3, 224,224)) 对自己的module进行特别的计算: class YourMo
-
在Pytorch中计算卷积方法的区别详解(conv2d的区别)
在二维矩阵间的运算: class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 对由多个特征平面组成的输入信号进行2D的卷积操作.详解 torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
-
Pytorch中的自动求梯度机制和Variable类实例
自动求导机制是每一个深度学习框架中重要的性质,免去了手动计算导数,下面用代码介绍并举例说明Pytorch的自动求导机制. 首先介绍Variable,Variable是对Tensor的一个封装,操作和Tensor是一样的,但是每个Variable都有三个属性:Varibale的Tensor本身的.data,对应Tensor的梯度.grad,以及这个Variable是通过什么方式得到的.grad_fn,根据最新消息,在pytorch0.4更新后,torch和torch.autograd.Variab
-
在pytorch中计算准确率,召回率和F1值的操作
看代码吧~ predict = output.argmax(dim = 1) confusion_matrix =torch.zeros(2,2) for t, p in zip(predict.view(-1), target.view(-1)): confusion_matrix[t.long(), p.long()] += 1 a_p =(confusion_matrix.diag() / confusion_matrix.sum(1))[0] b_p = (confusion_matri
-
解决pytorch中的kl divergence计算问题
偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中计算结果不同,平时没有注意到,记录下 一篇关于KL散度.JS散度以及交叉熵对比的文章 kl divergence 介绍 KL散度( Kullback–Leibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法.计算公式: 可以发现,P 和 Q 中元素的个数不用相等
-
解决pytorch GPU 计算过程中出现内存耗尽的问题
Pytorch GPU运算过程中会出现:"cuda runtime error(2): out of memory"这样的错误.通常,这种错误是由于在循环中使用全局变量当做累加器,且累加梯度信息的缘故,用官方的说法就是:"accumulate history across your training loop".在默认情况下,开启梯度计算的Tensor变量是会在GPU保持他的历史数据的,所以在编程或者调试过程中应该尽力避免在循环中累加梯度信息. 下面举个栗子: 上代
-
Pytorch中accuracy和loss的计算知识点总结
这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoch): train_loss = 0 train_acc = 0 num_correct= 0 for step, (x,y) in enumerate(train_loader): # x: [b, 3, 224, 224], y: [b] x, y = x.to(device), y.to(de
-
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
公式 首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的: 其中,其中yi表示真实的分类结果.这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文. 测试代码(一维) import torch import torch.nn as nn import math criterion = nn.CrossEntropyLoss() output = torch.randn(1, 5, requires_grad=True) label = tor
-
pytorch中的卷积和池化计算方式详解
TensorFlow里面的padding只有两个选项也就是valid和same pytorch里面的padding么有这两个选项,它是数字0,1,2,3等等,默认是0 所以输出的h和w的计算方式也是稍微有一点点不同的:tf中的输出大小是和原来的大小成倍数关系,不能任意的输出大小:而nn输出大小可以通过padding进行改变 nn里面的卷积操作或者是池化操作的H和W部分都是一样的计算公式:H和W的计算 class torch.nn.MaxPool2d(kernel_size, stride=Non
随机推荐
- windows下定时利用bat脚本实现ftp上传下载
- ASP.NET使用ajax实现分页局部刷新页面功能
- 详解vue-router 路由元信息
- 简单实现js悬浮导航效果
- asp.net使用FCK编辑器中的分页符实现长文章分页功能
- js中将URL中的参数提取出来作为对象的实现代码
- js鼠标移动时禁止选中文字
- 优化mysql数据库的经验总结
- Android内存泄漏排查利器LeakCanary
- 把字符串按照特定的字母顺序进行排序的js代码
- PHP中的output_buffering详细介绍
- jQuery实现在列表的首行添加数据
- asp.net 半角全角转化工具
- Android几行代码实现监听微信聊天示例
- AjaxControlToolKit 显示浏览者本地语言的方法
- Android使alertDialog.builder不会点击外面和按返回键消失的方法
- jquery实现奇偶行赋值不同css值
- Java 8 开发的 Mybatis 注解代码生成工具
- MySQL按时间统计数据的方法总结
- python实现判断一个字符串是否是合法IP地址的示例