如何在keras中添加自己的优化器(如adam等)

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers

model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))

sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 使用Keras中的ImageDataGenerator进行批次读图方式

    ImageDataGenerator位于keras.preprocessing.image模块当中,可用于做数据增强,或者仅仅用于一个批次一个批次的读进图片数据.一开始以为ImageDataGenerator是用来做数据增强的,但我的目的只是想一个batch一个batch的读进图片而已,所以一开始没用它,后来发现它是有这个功能的,而且使用起来很方便. ImageDataGenerator类包含了如下参数:(keras中文教程) ImageDataGenerator(featurewise_cen

  • 浅谈keras通过model.fit_generator训练模型(节省内存)

    前言 前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的. 如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,

  • 浅谈keras2 predict和fit_generator的坑

    1.使用predict时,必须设置batch_size,否则效率奇低. 查看keras文档中,predict函数原型: predict(self, x, batch_size=32, verbose=0) 说明: 只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测.在一些问题中,batch_size=32明显是非常小的.而通过PCI传数据是非常耗时的. 所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小

  • 在keras中model.fit_generator()和model.fit()的区别说明

    首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练. keras中文文档 fit fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=N

  • 如何在keras中添加自己的优化器(如adam等)

    本文主要讨论windows下基于tensorflow的keras 1.找到tensorflow的根目录 如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到.若并非使用默认安装路径,可参照根目

  • Keras SGD 随机梯度下降优化器参数设置方式

    SGD 随机梯度下降 Keras 中包含了各式优化器供我们使用,但通常我会倾向于使用 SGD 验证模型能否快速收敛,然后调整不同的学习速率看看模型最后的性能,然后再尝试使用其他优化器. Keras 中文文档中对 SGD 的描述如下: keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False) 随机梯度下降法,支持动量参数,支持学习衰减率,支持Nesterov动量 参数: lr:大或等于0的浮点数,学习率 momen

  • 如何在django中添加日志功能

    官方文档 猛戳这里 在settings中配置以下代码 #LOGGING_DIR 日志文件存放目录 LOGGING_DIR = "logs" # 日志存放路径 if not os.path.exists(LOGGING_DIR): os.mkdir(LOGGING_DIR) import logging LOGGING = { 'version': 1, 'disable_existing_loggers': False, 'formatters': { #格式化器 'standard'

  • 如何在postman中添加cookie信息步骤解析

    在测试工作中,很多的接口都依赖于登录接口,即在调用该接口前必须有登录的信息,否则调用会报错,那如何在postman中添加cookie信息呢?主要分为两个步骤,下面为大家详细介绍: 第一步:我们首先使用postman访问登录接口,在response中找到返回的cookie信息,并拷贝: 第二步:这里我们为大家介绍两种方法 方法1:在请求接口的headers中将cookie值设置为拷贝的cookie信息: 方法2:将cookie设置为全局变量,请求接口时直接调用即可. 全局变量的配置如下: 以上就是

  • Java如何在PDF中添加ToolTip工具提示

    目录 前言 导入jar包 添加工具提示ToolTip 总结 前言 本文,将介绍如何通过Java后端程序代码在PDF中创建工具提示.添加工具提示后,当鼠标悬停在页面上的元素时,将显示工具提示内容. 导入jar包 本次程序中使用的是Free Spire.PDF for Java,具体导入jar文件的方法参考如下内容. 两种方法可导入jar到程序: 方法1. 通过Maven仓库下载导入.在pom.xml配置:​ <repositories> <repository> <id>

  • 教你如何在IDEA 中添加 Maven 项目的 Archetype(解决添加不起作用的问题)

    目录 前言 实现过程 新建模块 添加脚手架 前言 在 IDEA 中点击新建 Maven 模块,会发现他已经为我们罗列出来了许多的 archetype,但有些时候满足不了我们的需求.下面就来看看如何添加自己的脚手架吧. 实现过程 新建模块 在 IDEA 中新建一个模块,需要保证每个目录下都至少有一个文件,不然打包的时候那个文件夹会被忽略掉,这里使用的项目结构如下图所示: 1.创建脚手架并打包打开终端,cd 到这个模块的根目录,比如这里是 D:/Java_Study/idea_projects/sp

  • 详解如何在nuxt中添加proxyTable代理

    背景 在本地开发vue项目的时候,当你习惯了proxyTable解决本地跨域的问题,切换到nuxt的时候,你会发现,添加了proxyTable设置并没有什么作用,那是因为你是用的vue脚手架生成的vue项目,它里面已经帮你写好了相关的proxyTable的设置代码. build/dev-server.js // proxy api requests Object.keys(proxyTable).forEach(function (context) { var options = proxyTa

  • Yii2 如何在modules中添加验证码的方法

    最近玩了下Yii2的验证码部分,正常的逻辑都可以走通的,网上的例子也是没有问题的,关键有问题的部分是在module中使用的时候,分享给大家,往下看之前可以去看看正常情况下是如何使用的. controller部分的代码,这里的跟网上的都类似 public function actions() { return [ 'captcha' => [ 'class' => 'yii\captcha\CaptchaAction', 'fixedVerifyCode' => null, 'backCo

  • 如何在Django中添加没有微秒的 DateTimeField 属性详解

    前言 今天在项目中遇到一个Django的大坑,一个很简单的分页问题,造成了数据重复.最后排查发现是DateTimeField 属性引起的. 下面描述下问题,下面是我需要用到的一个 Task Model 基本定义: class Task(models.Model): # ...... 省略了其他字段 title = models.CharField(max_length=256, verbose_name=u'标题') created_at = models.DateTimeField(auto_

  • 在spring-boot工程中添加spring mvc拦截器

    1. 认识拦截器 Spring MVC的拦截器(Interceptor)不是Filter,同样可以实现请求的预处理.后处理.使用拦截器仅需要两个步骤: 实现拦截器 注册拦截器 1.1 实现拦截器 实现拦截器可以自定义实现HandlerInterceptor接口,也可以通过继承HandlerInterceptorAdapter类,后者是前者的实现类.下面是拦截器的一个实现的例子,目的是判断用户是否登录.如果preHandle方法return true,则继续后续处理. public class L

随机推荐