keras自定义损失函数并且模型加载的写法介绍

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 使用Keras加载含有自定义层或函数的模型操作

    当我们导入的模型含有自定义层或者自定义函数时,需要使用custom_objects来指定目标层或目标函数. 例如: 我的一个模型含有自定义层"SincConv1D",需要使用下面的代码导入: from keras.models import load_model model = load_model('model.h5', custom_objects={'SincConv1D': SincConv1D}) 如果不加custom_objects指定目标层Layer,则会出现以下报错:

  • 浅谈keras中自定义二分类任务评价指标metrics的方法以及代码

    对于二分类任务,keras现有的评价指标只有binary_accuracy,即二分类准确率,但是评估模型的性能有时需要一些其他的评价指标,例如精确率,召回率,F1-score等等,因此需要使用keras提供的自定义评价函数功能构建出针对二分类任务的各类评价指标. keras提供的自定义评价函数功能需要以如下两个张量作为输入,并返回一个张量作为输出. y_true:数据集真实值组成的一阶张量. y_pred:数据集输出值组成的一阶张量. tf.round()可对张量四舍五入,因此tf.round(

  • keras 自定义loss损失函数,sample在loss上的加权和metric详解

    首先辨析一下概念: 1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的 2. metric只是作为评价网络表现的一种"指标", 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程 在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如: # 方式一 def vae_loss(x, x_decoded_mean): xent_loss = objectives.binary_

  • Keras之自定义损失(loss)函数用法说明

    在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式: def my_loss(y_true, y_pred): # y_true: True labels. TensorFlow/Theano tensor # y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true . . . return scalar #返回一个标量

  • keras自定义损失函数并且模型加载的写法介绍

    keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可.如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里'accuracy'是keras自带的度量函数. def focal_loss(): ... return xx def fbeta_score(): ... return yy model.compile(optimizer=Adam(lr=0.0001), lo

  • Android自定义webView头部进度加载效果

    不多说先来看下效果图: 1. 颜色渐变加载进度条(夜神模拟器) 绿色加载进度条(魅蓝note2) 看图说话: 上图是不是加载网页的时候会有一个进度条在横向加载,比以前网速不好的时候是一片空白给人的感觉友好多了是不,然后效果还不错. 实现思路 就是自己画一条进度线(大家应该都会吧)然后加载到WebView的上面,开始进度条是隐藏的,进度线初始值为1,然后为了效果好一点,初始少于10的进度都让它加载到10的位置,等进度到100的时候0.2秒后隐藏. 请记得添加网络权限: <uses-permissi

  • Android自定义View实现圆形加载进度条

    本文实例为大家分享了Android自定义View实现圆形加载进度条的具体代码,供大家参考,具体内容如下 效果图 话不多说,咱们直接看代码 首先第一种: 1.创建自定义View类 public class MyRelative extends View {        public MyRelative(Context context) {         this(context, null); //手动改成this...     }       public MyRelative(Conte

  • Angular4如何自定义首屏的加载动画详解

    前言 相信大家都知道,在默认情况下,Angular应用程序在首次加载根组件时,会在浏览器的显示一个loading... 我们可以轻松地将loading修改成我们自己定义的动画.下面话不多说,来一起看看详细的介绍: 这是我们要实现首次加载的效果: 根组件标签中的内容 请注意:在你的入口文件index.html中,默认的loading...只是插入到根组件标签之间: <!doctype html> <html> <head> <meta charset="u

  • 编写自定义的Django模板加载器的简单示例

    Djangos 内置的模板加载器(在先前的模板加载内幕章节有叙述)通常会满足你的所有的模板加载需求,但是如果你有特殊的加载需求的话,编写自己的模板加载器也会相当简单. 比如:你可以从数据库中,或者利用Python的绑定直接从Subversion库中,更或者从一个ZIP文档中加载模板. 模板加载器,也就是 TEMPLATE_LOADERS 中的每一项,都要能被下面这个接口调用: load_template_source(template_name, template_dirs=None) 参数 t

  • Android 自定义标题栏 显示网页加载进度的方法实例

    这阵子在做Lephone的适配,测试组提交一个bug:标题栏的文字较长时没有显示完全,其实这并不能算个bug,并且这个问题在以前其他机器也没有出现,只是说在Lephone的这个平台上显示得不怎么美观,因为联想将原生的标题栏UI进行了修改.修改的过程中遇到了一个难题,系统自带的那个标题栏进度总能够到达100%后渐退,但是我每次最后到100%那一段显示不全,尝试了用线程程序死了卡主了不说,还是一样的效果,后来同事一句话提醒了我用动画.确实是这样我猜系统的也是这样实现的,等进度到达100%后,用动画改

  • 自定义log4j.properties的加载位置方式

    目录 自定义log4j.properties加载位置 方法一 方法二 方法三 log4j.properties自定义路径 在web.xml 下面配这些参数 自定义log4j.properties加载位置 方法一 在main函数中添加如下代码 public class App { static final Logger logger = Logger.getLogger(App.class); public static void main( String[] args ) { PropertyC

  • Android 常见的图片加载框架详细介绍

    Android 常见的图片加载框架 图片加载涉及到图片的缓存.图片的处理.图片的显示等.而随着市面上手机设备的硬件水平飞速发展,对图片的显示要求越来越高,稍微处理不好就会造成内存溢出等问题.很多软件厂家的通用做法就是借用第三方的框架进行图片加载. 开源框架的源码还是挺复杂的,但使用较为简单.大部分框架其实都差不多,配置稍微麻烦点,但是使用时一般只需要一行,显示方法一般会提供多个重载方法,支持不同需要.这样会减少很不必要的麻烦.同时,第三方框架的使用较为方便,这大大的减少了工作量.提高了开发效率.

  • 浅谈Tensorflow模型的保存与恢复加载

    近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测.我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存方式. 保存checkpoint模型文件(.ckpt) 首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型. 模型保存 使用tf.trai

  • pytorch加载自定义网络权重的实现

    在将自定义的网络权重加载到网络中时,报错: AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead. 我们一步一步分析. 模型网络权重保存额代码是:torch.sa

随机推荐