keras回调函数的使用

目录
  • 回调函数
    • fit()方法中使用callbacks参数
    • 模型的保存和加载
    • 通过对Callback类子类化来创建自定义回调函数
    • 【其他】模型的定义 和 数据加载

回调函数

  • 回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用
  • 可以访问关于模型状态与模型性能的所有可用数据
  • 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。
  • 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。
  • 在训练过程中动态调节某些参数值:比如调节优化器的学习率。
  • 在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。

fit()方法中使用callbacks参数

# 这里有两个callback函数:早停和模型检查点
callbacks_list=[
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",#监控指标
        patience=2 #两轮内不再改善中断训练
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path",
        monitor="val_loss",
        save_best_only=True
    )
]
#模型获取
model=get_minist_model()
model.compile(optimizer="rmsprop",
             loss="sparse_categorical_crossentropy",
             metrics=["accuracy"])

model.fit(train_images,train_labels,
         epochs=10,callbacks=callbacks_list, #该参数使用回调函数
         validation_data=(val_images,val_labels))

test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标
predictions=model.predict(test_images)#计算模型在新数据上的分类概率

模型的保存和加载

#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。
#重新加载模型
model_new=keras.models.load_model("checkpoint_path.keras")

通过对Callback类子类化来创建自定义回调函数

on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用

from matplotlib import pyplot as plt
# 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}")
        self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

【其他】模型的定义 和 数据加载

def get_minist_model():
    inputs=keras.Input(shape=(28*28,))
    features=layers.Dense(512,activation="relu")(inputs)
    features=layers.Dropout(0.5)(features)
    outputs=layers.Dense(10,activation="softmax")(features)
    model=keras.Model(inputs,outputs)
    return model

#datset
from tensorflow.keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
train_images=train_images.reshape((60000,28*28)).astype("float32")/255
test_images=test_images.reshape((10000,28*28)).astype("float32")/255
train_images,val_images=train_images[10000:],train_images[:10000]
train_labels,val_labels=train_labels[10000:],train_labels[:10000]

到此这篇关于keras回调函数的使用的文章就介绍到这了,更多相关keras回调函数内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • keras 回调函数Callbacks 断点ModelCheckpoint教程

    整理自keras:https://keras-cn.readthedocs.io/en/latest/other/callbacks/ 回调函数Callbacks 回调函数是一个函数的合集,会在训练的阶段中所使用.你可以使用回调函数来查看训练模型的内在状态和统计.你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 Sequential 或 Model 类型的 .fit() 方法.在训练时,相应的回调函数的方法就会被在各自的阶段被调用. Callback keras.callb

  • keras自定义回调函数查看训练的loss和accuracy方式

    前言: keras是一个十分便捷的开发框架,为了更好的追踪网络训练过程中的损失函数loss和准确率accuracy,我们有几种处理方式,第一种是直接通过 history=model.fit(),来返回一个history对象,通过这个对象可以访问到训练过程训练集的loss和accuracy以及验证集的loss和accuracy. 第二种方式就是通过自定义一个回调函数Call backs,来实现这一功能,本文主要讲解第二种方式. 一.如何构建回调函数Callbacks 本文所针对的例子是卷积神经网络

  • 基于keras中的回调函数用法说明

    keras训练 fit( self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[], validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None ) 1. x:输入数据.如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应

  • keras回调函数的使用

    目录 回调函数 fit()方法中使用callbacks参数 模型的保存和加载 通过对Callback类子类化来创建自定义回调函数 [其他]模型的定义 和 数据加载 回调函数 回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用 可以访问关于模型状态与模型性能的所有可用数据 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态. 提前终止(early stopping):如果验证损失不再改善,

  • Kears 使用:通过回调函数保存最佳准确率下的模型操作

    1:首先,我给我的MixTest文件夹里面分好了类的图片进行重命名(因为分类的时候没有注意导致命名有点不好) def load_data(path): Rename the picture [a tool] for eachone in os.listdir(path): newname = eachone[7:] os.rename(path+"\\"+eachone,path+"\\"+newname) 但是需要注意的是:我们按照类重命名了以后,系统其实会按照图

  • PHP回调函数概念与用法实例分析

    本文实例讲述了PHP回调函数概念与用法.分享给大家供大家参考,具体如下: 一.回调函数的概念 先看一下C语言里的回调函数:回调函数就是一个通过函数指针调用的函数.如果你把函数的指针(地址)作为参数传递给另一个函数,当这个指针被用来调用其所指向的函数时,我们就说这是回调函数.回调函数不是由该函数的实现方直接调用,而是在特定的事件或条件发生时由另外的一方调用的,用于对该事件或条件进行响应. 其他语言里的回调函数的概念与之相似,只不过各种语言里回调函数的实现机制不一样,通俗的来说,回调函数是一个我们定

  • JS动态插入并立即执行回调函数的方法

    本文实例讲述了JS动态插入并立即执行回调函数的方法.分享给大家供大家参考,具体如下: <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <html xmlns="http://www.w3.org/1999/xhtml"> <

  • PHP回调函数与匿名函数实例详解

    本文实例讲述了PHP回调函数与匿名函数.分享给大家供大家参考,具体如下: 回调函数和匿名函数 回调函数.闭包在JS中并不陌生,JS使用它可以完成事件机制,进行许多复杂的操作.PHP中却不常使用,今天来说一说PHP中中的回调函数和匿名函数. 回调函数 回调函数:Callback (即call then back 被主函数调用运算后会返回主函数),是指通过函数参数传递到其它代码的,某一块可执行代码的引用. 通俗的解释就是把函数作为参数传入进另一个函数中使用:PHP中有许多 "需求参数为函数"

  • PHP中call_user_func_array回调函数的用法示例

    call_user_func_array call_user_func_array - 调用回调函数,并把一个数组参数作为回调函数的参数 mixed call_user_func_array ( callable $callback , array $param_arr ) 把第一个参数作为回调函数(callback)调用,把参数数组作(param_arr)为回调函数的的参数传入. 例子: function foobar($arg, $arg2) { echo __FUNCTION__, " g

  • PHP将回调函数作用到给定数组单元的方法

    数组是PHP程序设计中十分重要的一环.本文介绍PHP中数组函数array_map()的用法,实现将回调函数作用到给定数组单元上.具体如下: array array_map ( callable $callback , array $arr1 [, array $... ] ) array_map() 返回一个数组,该数组包含了 arr1 中的所有单元经过 callback 作用过之后的单元. callback 接受的参数数目应该和传递给 array_map() 函数的数组数目一致. 示例程序如下

随机推荐