浅谈Tensorflow加载Vgg预训练模型的几个注意事项

写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64。本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug。

加载图片

首先,造成上述Bug的代码如下所示

image_path = "data/test.jpg" # 本地的测试图片

image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.image.decode_jpeg(image_raw)

# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)

# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

上述代码是加载Vgg19预训练模型,并传入图片得到所有层的特征图,具体的代码实现和原理讲解可参考我的另一篇博客:Tensorflow加载Vgg预训练模型。那么,为什么代码会出现: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64,这个Bug呢?

这句英文翻译过来是指:传递的值类型是uint8,但是接受的参数类型必须是float的那几种。故原因就是传入值的数据类型错了,那么如何解决这个Bug呢,很简单

image_path = "data/test.jpg" # 本地的测试图片

image_raw = tf.gfile.GFile(image_path, 'rb').read()
# 一定要tf.float(),否则会报错
image_decoded = tf.to_float(tf.image.decode_jpeg(image_raw))

# 扩展图片的维度,从三维变成四维,符合Vgg19的输入接口
image_expand_dim = tf.expand_dims(image_decoded, 0)

# 定义Vgg19模型
vgg19 = VGG19(data_path)
net = vgg19.feed_forward(image_expand_dim, 'vgg19')
print(net)

这两个代码块唯一的变动就是:image_decoded结果在输出前加了一个tf.float(),将其转换为float类型。

在tensorflow API中,tf.image.decode_jpeg()默认读取的图片数据格式为unit8,而不是float。uint8数据的范围在(0, 255)中,正好符合图片的像素范围(0, 255)。但是,保存在本地的Vgg19预训练模型的数据接口为float,所以才造成了本文开头的Bug。

这里还要提一点,若是使用PIL的方法来加载图片,则不会出现上述的Bug,因为通过PIL得到的图片格式是float,而不是uint8,故不需要转换。

很多同学可能会疑惑,若是强行改变了原图片的数据格式,从uint8类型转变成float,会不会导致数据改变或者出错?故我做了下面这个实验:

image_path = "data/3.jpg"
image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_unit8 = tf.image.decode_jpeg(image_raw)
image_float = tf.to_float(image_unit8)

with tf.Session() as sess:
 image_unit8_, image_float_ = sess.run([image_unit8, image_float])

print("image_unit8_", image_unit8_)
print("image_float_ ", image_float_ )

代码结果如下:

 image_unit8_
 [180, 192, 204],
 [183, 195, 207],
 [186, 198, 210],
 ...,
 [191, 205, 218],
 [191, 205, 218],
 [190, 204, 217]],

 image_float_
 [180., 192., 204.],
 [183., 195., 207.],
 [186., 198., 210.],
 ...,
 [191., 205., 218.],
 [191., 205., 218.],
 [190., 204., 217.]],

可以看到,数据根本没有变化,只是后面多加了个小数点,变得只有类型,而没有强制改变值,故同学们不需要过度担心。

保存图片到本地

在加载图片的时候,为了使用保存在本地的预训练Vgg19模型,我们需要将读取的图片由uint8格式转换成float格式。那若是我们想将已经转换为float格式的图片再保存到本地,该怎么做呢?

首先,我们根据上述的文字的意思读取图片,并且将其转换为float格式,在将读取的图片再次保存到本地之前,我们首先可视化一下转换格式后的图片,代码如下:

import tensorflow as tf
from matplotlib import pyplot as plt
image_path = "data/boat.jpg"

image_raw = tf.gfile.GFile(image_path, 'rb').read()
image_decoded = tf.image.decode_jpeg(image_raw)
image_decoded = tf.to_float(image_decoded)

with tf.Session() as sess:
 image_decoded_ = sess.run(image_decoded)
 plt.imshow(image_decoded_)
 plt.show()

生成的图片如下图所示:

左边是原图,右边是转换为float格式的图片,可见将图片转换为float格式,虽然数值没有造成太大影响,但是若想将图片保存到本地就会出现问题。

说了这么多,只为了说一点,在保存图片到本地之前,需要将其格式从float转回uint8,否则会造成一系列错误:图片显示异常,API报错等。正确的保存代码如下:

save_path = "data/boat_copy.jpg"
image_uint = tf.cast(image_decoded, tf.uint8)
with tf.Session() as sess:
 with open(save_path, 'wb') as img:
 image_saved = sess.run(tf.image.encode_jpeg(image_uint))
 img.write(image_saved)

其中只有一句话最关键,即 tf.cast(image_decoded, tf.uint8)。

以上这篇浅谈Tensorflow加载Vgg预训练模型的几个注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFlow keras卷积神经网络 添加L2正则化方式

    我就废话不多说了,大家还是直接看代码吧! model = keras.models.Sequential([ #卷积层1 keras.layers.Conv2D(32,kernel_size=5,strides=1,padding="same",data_format="channels_last",activation=tf.nn.relu,kernel_regularizer=keras.regularizers.l2(0.01)), #池化层1 keras.l

  • Keras使用ImageNet上预训练的模型方式

    我就废话不多说了,大家还是直接看代码吧! import keras import numpy as np from keras.applications import vgg16, inception_v3, resnet50, mobilenet #Load the VGG model vgg_model = vgg16.VGG16(weights='imagenet') #Load the Inception_V3 model inception_model = inception_v3.I

  • Tensorflow tf.tile()的用法实例分析

    tf.tile()应用于需要张量扩展的场景,具体说来就是: 如果现有一个形状如[width, height]的张量,需要得到一个基于原张量的,形状如[batch_size,width,height]的张量,其中每一个batch的内容都和原张量一模一样.tf.tile使用方法如: tile( input, multiples, name=None ) import tensorflow as tf a = tf.constant([7,19]) a1 = tf.tile(a,multiples=[

  • keras模型保存为tensorflow的二进制模型方式

    最近需要将使用keras训练的模型移植到手机上使用, 因此需要转换到tensorflow的二进制模型. 折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码: # coding=utf-8 import sys from keras.models import load_model import tensorflow as tf import os import os.path as osp from keras import backend as K def freeze_session

  • 浅谈Tensorflow加载Vgg预训练模型的几个注意事项

    写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64.本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug. 加载图片 首先,造成上述Bug的代码如下所示 image_path = "data/test.jpg" # 本地的测试图片

  • Tensorflow加载Vgg预训练模型操作

    很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移.目标检测.图像标注等计算机视觉中常见的任务.那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢? 本文将以Vgg19为例子,详细说明Tensorflow如何加载Vgg预训练模型. 实验环境 GTX1050-ti, cuda9.0 Window10, Tensorflow 1.12 展示Vgg19构造 import tensorflow as tf import numpy as np impo

  • Pytorch加载部分预训练模型的参数实例

    前言 自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了.对于深度学习的初学者,Pytorch值得推荐.今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程. 直接加载预选脸模型 如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直

  • 浅谈Volley加载不出图片的问题

    问题分析:加载后台图片的时候,发现加载不出来,后来发现图片的url格式是: http://192.168.1.71/\carhome\shop\778c2bc3ec0a49e1969b24b3a8e62f31\detail\DSC04209.JPG 因为Volley请求,不识别url中有"\" 所以需要把"|"替换成"/" 以下是工具类 public class StringUtil { public static String formatUr

  • 浅谈BeanPostProcessor加载次序及其对Bean造成的影响分析

    前言 BeanPostProcessor是一个工厂钩子,允许Spring框架在新创建Bean实例时对其进行定制化修改.例如:通过检查其标注的接口或者使用代理对其进行包裹.应用上下文会从Bean定义中自动检测出BeanPostProcessor并将它们应用到随后创建的任何Bean上. 普通Bean对象的工厂允许在程序中注册post-processors,应用到随后在本工厂中创建的所有Bean上.典型的场景如:post-processors使用postProcessBeforeInitializat

  • 浅谈vue加载优化策略

    vue.js是一个比较流行的前端框架,与react.js.angular.js相比来说,vue.js入手曲线更加流畅,不管掌握多少都可以快速上手.但是单页面应用也都有其弊病,有时候首屏加载慢的让人捏舌.今天我们以vue cli3.x来说一说如何行之有效的缓解此问题! 方法一 路由懒加载 首屏加载慢的原因无非就是单页面应用需要加载完整个路由表上的页面,而路由懒加载就是来解决这个问题的.如果我们能把不同路由对应的组件分割成不同的代码块,然后当路由被访问的时候才加载对应组件,这样就更加高效了.下面这个

  • 浅谈vue-cli加载不到dev-server.js的解决办法

    在使用vue开发过程中,难免需要去本地数据地址进行请求,而原版配置在dev-server.js中,新版vue-webpack-template已经删除dev-server.js,改用webpack.dev.conf.js代替,所以 配置本地访问在webpack.dev.conf.js里配置即可. #webpack.dev.conf.js //首先 // nodejs开发框架express,用来简化操作 const express = require('express') // 创建node.js

  • Tensorflow加载预训练模型和保存模型的实例

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获! 1 Tensorflow模型文件 我们在checkpoint_dir目录下保存的文件结构如下: |--checkpoint_dir | |--checkpoint | |--MyModel.meta | |--MyModel.data-00000-of-00001 | |--MyModel.in

  • 浅谈tensorflow之内存暴涨问题

    在用tensorflow实现一些模型的时候,有时候我们在运行程序的时候,会发现程序占用的内存在不断增长.最后内存溢出,程序被kill掉了. 这个问题,其实有两个可能性.一个是比较常见,同时也是很难发现的.这个问题的解决,需要我们知道tensorflow在构图的时候,是没有所谓的临时变量的,只要有operator.那么tensorflow就会在构建的图中增加这个operator所代表的节点.所以,在运行程序的过程中,内存不断增长的原因就是在模型训练迭代的过程中,tensorflow一直在帮你增加图

  • 浅谈tensorflow模型保存为pb的各种姿势

    一,直接保存pb 1, 首先我们当然可以直接在tensorflow训练中直接保存为pb为格式,保存pb的好处就是使用场景是实现创建模型与使用模型的解耦,使得创建模型与使用模型的解耦,使得前向推导inference代码统一.另外的好处就是保存为pb的时候,模型的变量会变成固定的,导致模型的大小会大大减小. 这里稍稍解释下pb:是MetaGraph的protocol buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出 主要使用tf.SavedModelBuilde

随机推荐