pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

CrossEntropyLoss带权重的计算公式为(默认weight=None):

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch中accuracy和loss的计算知识点总结

    这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoch): train_loss = 0 train_acc = 0 num_correct= 0 for step, (x,y) in enumerate(train_loader): # x: [b, 3, 224, 224], y: [b] x, y = x.to(device), y.to(de

  • pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

    公式 首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的: 其中,其中yi表示真实的分类结果.这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文. 测试代码(一维) import torch import torch.nn as nn import math criterion = nn.CrossEntropyLoss() output = torch.randn(1, 5, requires_grad=True) label = tor

  • SSM框架中测试单元的使用 spring整合Junit过程详解

    测试类中的问题和解决思路 问题 在测试类中,每个测试方法都有以下两行代码: ApplicationContext ac = new ClassPathXmlApplicationContext("bean.xml"); IAccountService as = ac.getBean("accountService",IAccountService.class); 这两行代码的作用是获取容器,如果不写的话,直接会提示空指针异常.所以又不能轻易删掉. 解决思路分析 针对

  • PyTorch的SoftMax交叉熵损失和梯度用法

    在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算 关于softmax_cross_entropy求导的过程,可以参考HERE 示例: # -*- coding: utf-8 -*- import torch import torch.autograd as autograd from torch.autograd import Variable import torch.nn.functional as F import torch.nn as nn import nu

  • 解决pytorch 交叉熵损失输出为负数的问题

    网络训练中,loss曲线非常奇怪 交叉熵怎么会有负数. 经过排查,交叉熵不是有个负对数吗,当网络输出的概率是0-1时,正数.可当网络输出大于1的数,就有可能变成负数. 所以加上一行就行了 out1 = F.softmax(out1, dim=1) 补充知识:在pytorch框架下,训练model过程中,loss=nan问题时该怎么解决? 当我在UCF-101数据集训练alexnet时,epoch设为100,跑到三十多个epoch时,出现了loss=nan问题,当时是一脸懵逼,在查阅资料后,我通过

  • pytorch 中pad函数toch.nn.functional.pad()的用法

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: import torch.nn,functional as F import torch from PIL import Image im=Image.open("heibai.jpg",'r') X=torch.Tensor(np.asarray(im)) print("shape:

  • pytorch中的卷积和池化计算方式详解

    TensorFlow里面的padding只有两个选项也就是valid和same pytorch里面的padding么有这两个选项,它是数字0,1,2,3等等,默认是0 所以输出的h和w的计算方式也是稍微有一点点不同的:tf中的输出大小是和原来的大小成倍数关系,不能任意的输出大小:而nn输出大小可以通过padding进行改变 nn里面的卷积操作或者是池化操作的H和W部分都是一样的计算公式:H和W的计算 class torch.nn.MaxPool2d(kernel_size, stride=Non

  • 在pytorch中对非叶节点的变量计算梯度实例

    在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行. 注册hook函数 Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在

  • pytorch从头开始搭建UNet++的过程详解

    目录 Unet++代码 网络架构 Backbone 上采样 下采样 深度监督 网络架构代码 Unet是一个最近比较火的网络结构.它的理论已经有很多大佬在讨论了.本文主要从实际操作的层面,讲解pytorch从头开始搭建UNet++的过程. Unet++代码 网络架构 黑色部分是Backbone,是原先的UNet. 绿色箭头为上采样,蓝色箭头为密集跳跃连接. 绿色的模块为密集连接块,是经过左边两个部分拼接操作后组成的 Backbone 2个3x3的卷积,padding=1. class VGGBlo

  • C和C++中的基本数据类型的大小及表示范围详解

    本文研究的主要问题时关于C和C++中的基本数据类型int.long.long long.float.double.char.string的大小及表示范围,具体介绍如下. 一.基本类型的大小及范围的总结(以下所讲都是默认在32位操作系统下): 字节:byte:位:bit. 1.短整型short:所占内存大小:2byte=16bit: 所能表示范围:-32768~32767:(即-2^15~2^15-1) 2.整型int:所占内存大小:4byte=32bit: 所能表示范围:-2147483648~

  • PHP中的浅复制与深复制的实例详解

    PHP中的浅复制与深复制的实例详解 前言: 最近温习了一下Design Pattern方面的知识,在看到Prototype Pattern这一设计模式时,注意到其中涉及到一个浅复制与深复制的问题.这里来总结一下,提醒自己以后一定要多加注意. 自PHP5起,new运算符自动返回一个引用,一个 对象变量 已经不再保存整个对象的值,只是保存一个标识符来访问真正的对象内容.当对象作为参数传递,作为结果返回,或者赋值给另外一个变量,另外一个变量跟原来的不是引用的关系,只是他们都保存着同一个标识符的拷贝,这

随机推荐