基于MSELoss()与CrossEntropyLoss()的区别详解

基于pytorch来讲

MSELoss()多用于回归问题,也可以用于one_hotted编码形式,

CrossEntropyLoss()名字为交叉熵损失函数,不用于one_hotted编码形式

MSELoss()要求batch_x与batch_y的tensor都是FloatTensor类型

CrossEntropyLoss()要求batch_x为Float,batch_y为LongTensor类型

(1)CrossEntropyLoss() 举例说明:

比如二分类问题,最后一层输出的为2个值,比如下面的代码:

class CNN (nn.Module ) :
  def __init__ ( self , hidden_size1 , output_size , dropout_p) :
    super ( CNN , self ).__init__ ( )
    self.hidden_size1 = hidden_size1
    self.output_size = output_size
    self.dropout_p = dropout_p

    self.conv1 = nn.Conv1d ( 1,8,3,padding =1)
    self.fc1 = nn.Linear (8*500, self.hidden_size1 )
    self.out = nn.Linear (self.hidden_size1,self.output_size ) 

  def forward ( self , encoder_outputs ) :
    cnn_out = F.max_pool1d ( F.relu (self.conv1(encoder_outputs)),2)
    cnn_out = F.dropout ( cnn_out ,self.dropout_p) #加一个dropout
    cnn_out = cnn_out.view (-1,8*500)
    output_1 = torch.tanh ( self.fc1 ( cnn_out ) )
    output = self.out ( ouput_1)
    return output

最后的输出结果为:

上面一个tensor为output结果,下面为target,没有使用one_hotted编码。

训练过程如下:

cnn_optimizer = torch.optim.SGD(cnn.parameters(),learning_rate,momentum=0.9,\
              weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

def train ( input_variable , target_variable , cnn , cnn_optimizer , criterion ) :
  cnn_output = cnn( input_variable )
  print(cnn_output)
  print(target_variable)
  loss = criterion ( cnn_output , target_variable)
  cnn_optimizer.zero_grad ()
  loss.backward( )
  cnn_optimizer.step( )
  #print('loss: ',loss.item())
  return loss.item() #返回损失

说明CrossEntropyLoss()是output两位为one_hotted编码形式,但target不是one_hotted编码形式。

(2)MSELoss() 举例说明:

网络结构不变,但是标签是one_hotted编码形式。下面的图仅做说明,网络结构不太对,出来的预测也不太对。

如果target不是one_hotted编码形式会报错,报的错误如下。

目前自己理解的两者的区别,就是这样的,至于多分类问题是不是也是样的有待考察。

以上这篇基于MSELoss()与CrossEntropyLoss()的区别详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

    公式 首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的: 其中,其中yi表示真实的分类结果.这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文. 测试代码(一维) import torch import torch.nn as nn import math criterion = nn.CrossEntropyLoss() output = torch.randn(1, 5, requires_grad=True) label = tor

  • 基于MSELoss()与CrossEntropyLoss()的区别详解

    基于pytorch来讲 MSELoss()多用于回归问题,也可以用于one_hotted编码形式, CrossEntropyLoss()名字为交叉熵损失函数,不用于one_hotted编码形式 MSELoss()要求batch_x与batch_y的tensor都是FloatTensor类型 CrossEntropyLoss()要求batch_x为Float,batch_y为LongTensor类型 (1)CrossEntropyLoss() 举例说明: 比如二分类问题,最后一层输出的为2个值,比

  • 基于session_unset与session_destroy的区别详解

    session_unset()释放当前在内存中已经创建的所有$_SESSION变量,但不删除session文件以及不释放对应的sessionidsession_destroy()删除当前用户对应的session文件以及释放sessionid,内存中的$_SESSION变量内容依然保留因此,释放用户的session所有资源,需要顺序执行如下代码:程序代码 复制代码 代码如下: <?php$_SESSION['user'] = 'wangh';session_unset();session_dest

  • 基于python中staticmethod和classmethod的区别(详解)

    例子 class A(object): def foo(self,x): print "executing foo(%s,%s)"%(self,x) @classmethod def class_foo(cls,x): print "executing class_foo(%s,%s)"%(cls,x) @staticmethod def static_foo(x): print "executing static_foo(%s)"%x a=A(

  • 基于DOM节点删除之empty和remove的区别(详解)

    要移除页面上节点是开发者常见的操作,jQuery提供了几种不同的方法用来处理这个问题,这里我们开仔细了解下empty和remove方法 empty 顾名思义,清空方法,但是与删除又有点不一样,因为它只移除了 指定元素中的所有子节点. 这个方法不仅移除子元素(和其他后代元素),同样移除元素里的文本.因为,根据说明,元素里任何文本字符串都被看做是该元素的子节点.请看下面的HTML: <div class="hello"><p>这是p标签</p></

  • 基于js中this和event 的区别(详解)

    今天在看javascript入门经典-事件一章中看到了 this 和 event 两种传参形式.因为作为一个初级的前端开发人员平时只用过 this传参,so很想弄清楚,this和event的区别是什么,什么情况下用什么比较合适. onclick = changeImg(this)       vs     onclick = changeImg(event) <img src='usa.gif' onclick="changeImg(event)" /> <scrip

  • 基于Python __dict__与dir()的区别详解

    Python下一切皆对象,每个对象都有多个属性(attribute),Python对属性有一套统一的管理方案. __dict__与dir()的区别: dir()是一个函数,返回的是list: __dict__是一个字典,键为属性名,值为属性值: dir()用来寻找一个对象的所有属性,包括__dict__中的属性,__dict__是dir()的子集: 并不是所有对象都拥有__dict__属性.许多内建类型就没有__dict__属性,如list,此时就需要用dir()来列出对象的所有属性. __di

  • 基于Java中throw和throws的区别(详解)

    系统自动抛出的异常 所有系统定义的编译和运行异常都可以由系统自动抛出,称为标准异常,并且 Java 强烈地要求应用程序进行完整的异常处理,给用户友好的提示,或者修正后使程序继续执行. 语句抛出的异常 用户程序自定义的异常和应用程序特定的异常,必须借助于 throws 和 throw 语句来定义抛出异常. throw是语句抛出一个异常. 语法:throw (异常对象); throw e; throws是方法可能抛出异常的声明.(用在声明方法时,表示该方法可能要抛出异常) 语法:[(修饰符)](返回

  • 基于js中style.width与offsetWidth的区别(详解)

    作为一个初学者,经常会遇到在获取某一元素的宽度(高度.top值...)时,到底是用 style.width还是offsetWidth的疑惑. 1. 当样式写在行内的时候,如 <div id="box" style="width:100px">时,用 style.width或者offsetWidth都可以获取元素的宽度. 但是,当样式写在样式表中时,如 #box{ width: 100px; }, 此时只能用offsetWidth来获取元素的宽度,而sty

  • 基于js 字符串indexof与search方法的区别(详解)

    1.indexof方法 indexOf() 方法可返回某个指定的字符串值在字符串中首次出现的位置. 语法: 注意:有可选的参数(即设置开始的检索位置). 2.search方法 search() 方法用于检索字符串中指定的子字符串,或检索与正则表达式相匹配的子字符串. 注意:search方法可以根据正则表达式查找指定字符串(可以忽略大小写,并且不执行全局检索),同时没有可选参数(即设置开始的检索位置). 以上这篇基于js 字符串indexof与search方法的区别(详解)就是小编分享给大家的全部

  • 基于Django filter中用contains和icontains的区别(详解)

    qs.filter(name__contains="e") qs.filter(name__icontains="e") 对应sql 'contains': 'LIKE BINARY %s', 'icontains': 'LIKE %s', 其中的BINARY是 精确大小写 而'icontains'中的'i'表示忽略大小写 以上这篇基于Django filter中用contains和icontains的区别(详解)就是小编分享给大家的全部内容了,希望能给大家一个参考

随机推荐