详解tensorflow实现迁移学习实例

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持我们。

您可能感兴趣的文章:

  • 初探TensorFLow从文件读取图片的四种方式
  • 用tensorflow实现弹性网络回归算法
  • Tensorflow 利用tf.contrib.learn建立输入函数的方法
  • 基于ubuntu16 Python3 tensorflow(TensorFlow环境搭建)
  • TensorFlow实现创建分类器
  • TensorFlow高效读取数据的方法示例
  • TensorFlow如何实现反向传播
(0)

相关推荐

  • TensorFlow实现创建分类器

    本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下 创建一个iris数据集的分类器. 加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾.iris数据集有三类花,但这里仅预测是否是山鸢尾.导入iris数据集和工具库,相应地对原数据集进行转换. # Combining Everything Together #---------------------------------- # This file will perform binary c

  • 用tensorflow实现弹性网络回归算法

    本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下 python代码: #用tensorflow实现弹性网络算法(多变量) #使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征. #1 导入必要的编程库,创建计算图,加载数据集 import matplotlib.pyplot as plt import tensorflow as tf import numpy as np from sklearn import datasets from tensor

  • TensorFlow如何实现反向传播

    使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量. TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的.这步将通过声明优化函数(optimization function)来实现.一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项.当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量. 回归算法的例子从均值为1.标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函

  • 初探TensorFLow从文件读取图片的四种方式

    本文记录一下TensorFLow的几种图片读取方法,官方文档有较为全面的介绍. 1.使用gfile读图片,decode输出是Tensor,eval后是ndarray import matplotlib.pyplot as plt import tensorflow as tf import numpy as np print(tf.__version__) image_raw = tf.gfile.FastGFile('test/a.jpg','rb').read() #bytes img =

  • TensorFlow高效读取数据的方法示例

    概述 最新上传的mcnn中有完整的数据读写示例,可以参考. 关于Tensorflow读取数据,官网给出了三种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况). 对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(t

  • 基于ubuntu16 Python3 tensorflow(TensorFlow环境搭建)

    人最大的长处就是有厉害的大脑.电脑.手机等都是对人大脑的拓展.现今,我们每个人都有这个机会,让自己头脑在智能的帮助下,达到极高的高度.所以,拥抱科技,让智能产品成为我们个人智力的拓展,更好的去生活.去战斗. 用项目引导学习: 我们的目标是用现有最流行的谷歌开源框架TensorFlow,搭建一款儿童助学帮手.类似于现在已有的在售商品小米智能语音盒子之类的东西,. 一.Windows下安装虚拟机VMware Workstation,在虚拟机中安装Ubuntu(要善用搜索引擎,解决各类简单问题) 这里

  • Tensorflow 利用tf.contrib.learn建立输入函数的方法

    在实际的业务中,可能会遇到很大量的特征,这些特征良莠不齐,层次不一,可能有缺失,可能有噪声,可能规模不一致,可能类型不一样,等等问题都需要我们在建模之前,先预处理特征或者叫清洗特征.那么这清洗特征的过程可能涉及多个步骤可能比较复杂,为了代码的简洁,我们可以将所有的预处理过程封装成一个函数,然后直接往模型中传入这个函数就可以啦~~~ 接下来我们看看究竟如何做呢? 1. 如何使用input_fn自定义输入管道 当使用tf.contrib.learn来训练一个神经网络时,可以将特征,标签数据直接输入到

  • 详解tensorflow实现迁移学习实例

    本文主要是总结利用tensorflow实现迁移学习的基本步骤. 所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题.比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层.在最后一层全连接层之前的网络称之为瓶颈层(bottleneck). 持久化 首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类.当采用该

  • 用十张图详解TensorFlow数据读取机制(附代码)

    在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下TensorFlow的数据读取机制,文章的最后还会给出实战代码以供参考. TensorFlow读取机制图解 首先需要思考的一个问题是,什么是数据读取?以图像数据为例,读取数据的过程可以用下图来表示: 假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003.jpg--我们只需要把

  • 详解TensorFlow在windows上安装与简单示例

    本文介绍了详解TensorFlow在windows上安装与简单示例,分享给大家,具体如下: 安装说明 平台:目前可在Ubuntu.Mac OS.Windows上安装 版本:提供gpu版本.cpu版本 安装方式:pip方式.Anaconda方式 Tips: 在Windows上目前支持python3.5.x gpu版本需要cuda8,cudnn5.1 安装进度 2017/3/4进度: Anaconda 4.3(对应python3.6)正在安装,又删除了,一无所有了 2017/3/5进度: Anaco

  • 详解tensorflow之过拟合问题实战

    过拟合问题实战 1.构建数据集 我们使用的数据集样本特性向量长度为 2,标签为 0 或 1,分别代表了 2 种类别.借助于 scikit-learn 库中提供的 make_moons 工具我们可以生成任意多数据的训练集. import matplotlib.pyplot as plt # 导入数据集生成工具 import numpy as np import seaborn as sns from sklearn.datasets import make_moons from sklearn.m

  • 详解表单验证正则表达式实例(推荐)

    验证:!reg.test(value) 邮箱: 复制代码 代码如下: reg = /^\w+((-\w+)|(\.\w+))*\@[A-Za-z0-9]+((\.|-)[A-Za-z0-9]+)*\.[A-Za-z0-9]+$/i; 不包含中文: 复制代码 代码如下: reg = /.*[\u4e00-\u9fa5]+.*$/i; 身份证号: // 验证身份证号码 var city = {11:'北京',12:'天津',13:'河北',14:'山西',15:'内蒙古',21:'辽宁',22:'吉

  • 详解Tensorflow不同版本要求与CUDA及CUDNN版本对应关系

    参考官网地址: Windows端:https://tensorflow.google.cn/install/source_windows CPU Version Python version Compiler Build tools tensorflow-1.11.0 3.5-3.6 MSVC 2015 update 3 Cmake v3.6.3 tensorflow-1.10.0 3.5-3.6 MSVC 2015 update 3 Cmake v3.6.3 tensorflow-1.9.0

  • 详解vite2.0配置学习(typescript版本)

    介绍 尤于溪的原话. vite与 Vue CLI 类似,vite 也是一个提供基本项目脚手架和开发服务器的构建工具. vite基于浏览器原生ES imports的开发服务器.跳过打包这个概念,服务端按需编译返回. vite速度比webpack快10+倍,支持热跟新, 但是出于处于测试阶段. 配置文件也支持热跟新!!! 创建 执行npm init @vitejs/app ,我这里选择的是vue-ts 版本 "vite": "^2.0.0-beta.48" alias别

  • 详解TensorFlow训练网络两种方式

    TensorFlow训练网络有两种方式,一种是基于tensor(array),另外一种是迭代器 两种方式区别是: 第一种是要加载全部数据形成一个tensor,然后调用model.fit()然后指定参数batch_size进行将所有数据进行分批训练 第二种是自己先将数据分批形成一个迭代器,然后遍历这个迭代器,分别训练每个批次的数据 方式一:通过迭代器 IMAGE_SIZE = 1000 # step1:加载数据集 (train_images, train_labels), (val_images,

  • 一文详解Java过滤器拦截器实例逐步掌握

    目录 一.过滤器与拦截器相同点 二.过滤器与拦截器区别 三.过滤器与拦截器的实现 四.过滤器与拦截器相关面试题 一.过滤器与拦截器相同点 1.拦截器与过滤器都是体现了AOP的思想,对方法实现增强,都可以拦截请求方法. 2.拦截器和过滤器都可以通过Order注解设定执行顺序 二.过滤器与拦截器区别 在Java Web开发中,过滤器(Filter)和拦截器(Interceptor)都是常见的用于在请求和响应之间进行处理的组件.它们的主要区别如下: 运行位置不同:过滤器是运行在Web服务器和Servl

  • 详解Vue源码学习之双向绑定

    原理 当你把一个普通的 JavaScript 对象传给 Vue 实例的 data 选项,Vue 将遍历此对象所有的属性,并使用 Object.defineProperty 把这些属性全部转为 getter/setter.Object.defineProperty 是 ES5 中一个无法 shim 的特性,这也就是为什么 Vue 不支持 IE8 以及更低版本浏览器. 上面那段话是Vue官方文档中截取的,可以看到是使用Object.defineProperty实现对数据改变的监听.Vue主要使用了观

随机推荐