tensorflow2.0保存和恢复模型3种方法

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(
 filepath,
 overwrite=True,
 save_format=None
)

Arguments:

filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(
 filepath,
 by_name=False
)

实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers

# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()

# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])

# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )

# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))

# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')

# step7 删除模型
del model

# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])

# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')

# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(
 filepath,
 overwrite=True,
 include_optimizer=True,
 save_format=None
)

Arguments:

filepath: String, path to SavedModel or H5 file to save the model.

overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.

include_optimizer: If True, save optimizer's state together.

save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_modelinstead).

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers

# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# step2 创建模型
def create_model():
 return tf.keras.models.Sequential([
 tf.keras.layers.Flatten(input_shape=(28, 28)),
 tf.keras.layers.Dense(512, activation='relu'),
 tf.keras.layers.Dropout(0.2),
 tf.keras.layers.Dense(10, activation='softmax')
 ])
model = create_model()

# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])

# step4 模型训练 训练一个epochs
model.fit(x=x_train,
  y=y_train,
  epochs=1,
  )

# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))

# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'

# step7 删除模型
del model # deletes the existing model

# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')

# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

以上这篇tensorflow2.0保存和恢复模型3种方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • tensorflow2.0保存和恢复模型3种方法

    方法1:只保存模型的权重和偏置 这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同. tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了. save_weights( filepath, overwrite=True, save_format=None ) Arguments: filepath: String,

  • 浅谈Tensorflow模型的保存与恢复加载

    近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测.我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存方式. 保存checkpoint模型文件(.ckpt) 首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型. 模型保存 使用tf.trai

  • tensorflow1.0学习之模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver() 在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型.如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置

  • 基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torch.save(net.state_dict(),path): 功能:保存训练完的网络的各层参数(即weights和bias) 其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth) net2.load_state_dict(torch.loa

  • 配置SQL Server数据库恢复模式(2种方法)

    下面主要介绍配置SQL Server数据库恢复模式的两种方法. 用T-SQL设置恢复模式 你可以使用"ALTER DATABASE"命令加"SET RECOVERY"语句来修改数据库的恢复模式.例如,下面的查询语句把"AdventureWorks"数据库的恢复模式设置为完全恢复模式. ALTER DATABASE AdventureWorks SET RECOVERY FULL ; 你可以查询"sys.databases"的目

  • Oracle数据库数据丢失恢复的几种方法总结

    根据oracle数据库的特点和提供的工具,主要方法有以下几种方法: 利用逻辑备份使用import工具丢失数据的表 利用物理备份来通过还原数据文件并进行不完全恢复 利用dbms_logmnr包从redo log文件中恢复 利用flashback特性恢复数据 前提 为了方便使用方法的介绍,上述恢复方法都将基于以下场景进行:系统管理员在前一天晚上11点用export对数据库做了全库逻辑备份,然后对所有数据文件进行了热备份.第二天上午10点,系统管理员在修改表TFUNDASSET的数据时,由于修改语句的

  • SQL Server无日志恢复数据库(2种方法)

    SQL Server是一个关系数据库管理系统,应用很广泛,在进行SQL Server数据库操作的过程中难免会出现误删或者别的原因引起的日志损坏,又由于SQL Server数据库中数据的重要性,出现了以上的故障之后就必须对数据库中数据进行恢复.下文就为大家介绍一种恢复数据库日志文件的方法. 解决方法一 1.新建一个同名的数据库 2.再停掉sql server(注意不要分离数据库) 3.用原数据库的数据文件覆盖掉这个新建的数据库 4.再重启sql server 5.此时打开企业管理器时会出现置疑,先

  • 详解Springboot 优雅停止服务的几种方法

    在使用Springboot的时候,都要涉及到服务的停止和启动,当我们停止服务的时候,很多时候大家都是kill -9 直接把程序进程杀掉,这样程序不会执行优雅的关闭.而且一些没有执行完的程序就会直接退出. 我们很多时候都需要安全的将服务停止,也就是把没有处理完的工作继续处理完成.比如停止一些依赖的服务,输出一些日志,发一些信号给其他的应用系统,这个在保证系统的高可用是非常有必要的.那么咱么就来看一下几种停止springboot的方法. 第一种就是Springboot提供的actuator的功能,它

  • python运行时间的几种方法

    最早见过手写的,类似于下面这种: import datetime def time_1(): begin = datetime.datetime.now() sum = 0 for i in xrange(10000000): sum = sum + i end = datetime.datetime.now() return end-begin print time_1() 输出如下: ➜  Python python time_1.py 0:00:00.280797 另外一种方法是使用tim

  • Android 杀死进程几种方法详细介绍

    Android 杀死进程: 对于进程结束在开发APP应用当中还是有必要的,这里整理了三种方法,大家可以根据需求选用. 当应用不再使用时,通常需要关闭应用,可以使用以下三种方法关闭android应用: 第一种方法:首先获取当前进程的id,然后杀死该进程. android.os.Process.killProcess(android.os.Process.myPid()) 接下来实践一下: <RelativeLayout xmlns:android="http://schemas.androi

  • 详解mybatis 批量更新数据两种方法效率对比

    上节探讨了批量新增数据,这节探讨批量更新数据两种写法的效率问题. 实现方式有两种, 一种用for循环通过循环传过来的参数集合,循环出N条sql, 另一种 用mysql的case when 条件判断变相的进行批量更新 下面进行实现. 注意第一种方法要想成功,需要在db链接url后面带一个参数  &allowMultiQueries=true 即:  jdbc:mysql://localhost:3306/mysqlTest?characterEncoding=utf-8&allowMulti

随机推荐