pytorch实现focal loss的两种方式小结

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss

def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch中accuracy和loss的计算知识点总结

    这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoch): train_loss = 0 train_acc = 0 num_correct= 0 for step, (x,y) in enumerate(train_loader): # x: [b, 3, 224, 224], y: [b] x, y = x.to(device), y.to(de

  • pytorch实现focal loss的两种方式小结

    我就废话不多说了,直接上代码吧! import torch import torch.nn.functional as F import numpy as np from torch.autograd import Variable ''' pytorch实现focal loss的两种方式(现在讨论的是基于分割任务) 在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别 ''' def compute_class_weights(histogram): classWeigh

  • laravel实现上传图片的两种方式小结

    第一:是laravel里面自带的上传方式(写在接口里面的) function uploadAvatar(Request $request) { $user_id = Auth::id(); $avatar = $request->file('avatar')->store('/public/' . date('Y-m-d') . '/avatars'); //上传的头像字段avatar是文件类型 $avatar = Storage::url($avatar);//就是很简单的一个步骤 $res

  • vue data引入本地图片的两种方式小结

    我就废话不多说了,大家直接看吧! 第一种 <template> <img :src="imgsrc"> </template> <script> export default { data () { return { imgsrc: require('../../images/ICON-electronicbilling.png') } } } </script> 第二种 <template> <img :s

  • Springboot之修改启动端口的两种方式(小结)

    Springboot启动的时候,端口的设定默认是8080,这肯定是不行的,我们需要自己定义端口,Springboot提供了两种方式,第一种,我们可以通过application.yml配置文件配置,第二种,可以通过代码里面指定,在开发中,建议使用修改application.yml的方式来修改端口. 代码地址 #通过yml配置文件的方式指定端口地址 https://gitee.com/yellowcong/springboot-demo/tree/master/springboot-demo2 #硬

  • Spring MVC获取HTTP请求头的两种方式小结

    1 前言 请求是任何Web服务要关注的对象,而请求头也是其中非常重要的信息.本文将通过代码讲解如何在Spring MVC项目中获取请求头的内容.主要通过两种方式获取: (1)通过注解@RequestHeader获取,需要在Controller中显式获取: (2)通过RequestContextHolder获取,可以任何地方获取. 接下来通过代码讲解. 2 通过注解@RequestHeader获取 需要在Controller中显示使用@RequestHeader. 2.1 获取某个请求头 只获取其

  • redis实现延时队列的两种方式(小结)

    背景 项目中的流程监控,有几种节点,需要监控每一个节点是否超时.按传统的做法,肯定是通过定时任务,去扫描然后判断,但是定时任务有缺点:1,数据量大会慢:2,时间不好控制,太短,怕一次处理不完,太长状态就会有延迟.所以就想到用延迟队列的方式去实现. 一,redis的过期key监控 1,开启过期key监听 在redis的配置里把这个注释去掉 notify-keyspace-events Ex 然后重启redis 2,使用redis过期监听实现延迟队列 继承KeyExpirationEventMess

  • C#复杂XML反序列化为实体对象两种方式小结

    目录 前言 需要操作的Xml数据 一.通过是手写的方式去定义Xml的实体对象模型类 二.通过Visual Studio自带的生成Xml实体对象模型类 1.首先Ctrl+C复制你需要生成的Xml文档内容 2.找到编辑=>选择性粘贴=>将Xml粘贴为类 3.以下是使用VS自动生成的Xml类 验证两个Xml类是否能够反序列化成功 C# XML基础入门(XML文件内容增删改查清) C#XmlHelper帮助类操作Xml文档的通用方法汇总 .NET中XML序列化和反序列化常用类和用来控制XML序列化的属

  • Vue引入并使用Element组件库的两种方式小结

    目录 前言 Element-ui(饿了么ui) 安装element-ui 引入element-ui 完整引入element-u 按需引入element-ui 总结 前言 在开发的时候,虽然我们可以自己写css或者js甚至一些动画特效,但是也有很多开源的组件库帮我们写好了.我们只需要下载并引入即可. vue和element-ui在开发中是比较般配的,也是我们开发中用的很多的,下面就介绍下如何在eue项目中引入element-ui组件库 Element-ui(饿了么ui) element-ui(饿了

  • java实现消息队列的两种方式(小结)

    实现消息队列的两种方式 Apache ActiveMQ官方实例发送消息 直接在Apache官网http://activemq.apache.org/download-archives.html下载ActiveMQ源码 下载解压后拿到java代码实例 然后倒入IDE 如下: 请认真阅读readme.md文件,大致意思就是把项目打成两个jar包,然后启动服务,然后同时运行打的两个jar包,然后就能看到具体的调用信息.打jar包时直接利用maven打就行了,不用修改代码. 启动服务: 利用Spring

  • Java中Http连接的两种方式(小结)

    在java中连接http,介绍两种方法,一种是java的HttpUrlConnection,另一种是apacha公司的httpClient,后者是第三方的类库需要从外部,导入,同时这也是第一次使用外部的类库,以后还会有很多需要导入外部类库的需求. http协议是基于tcp的一种协议. tcp是一种保证可靠连接的传输协议,通过三次握手,和丢失重传的机制保证数据的传输. 首先来看HttpUrlConnection 这个类是java自带的,直接import就行. 使用tcp连接的过程几乎都一样,htt

随机推荐