keras .h5转移动端的.tflite文件实现方式

以前tensorflow有bug 在winodws下无法转,但现在好像没有问题了,代码如下

将keras 下的mobilenet_v2转成了tflite

from keras.backend import clear_session
import numpy as np
import tensorflow as tf
clear_session()
np.set_printoptions(suppress=True)
input_graph_name = "../models/weights.best_mobilenet224.h5"
output_graph_name = input_graph_name[:-3] + '.tflite'
converter = tf.lite.TFLiteConverter.from_keras_model_file(model_file=input_graph_name)
converter.post_training_quantize = True
#在windows平台这个函数有问题,无法正常使用
tflite_model = converter.convert()
open(output_graph_name, "wb").write(tflite_model)
print ("generate:",output_graph_name)

补充知识:如何把Tensorflow模型转换成TFLite模型

深度学习迅猛发展,目前已经可以移植到移动端使用了,TensorFlow推出的TensorFlow Lite就是一款把深度学习应用到移动端的框架技术。

使用TensorFlowLite 需要tflite文件模型,这个模型可以由TensorFlow训练的模型转换而成。所以首先需要知道如何保存训练好的TensorFlow模型。

一般有这几种保存形式:

1、Checkpoints

2、HDF5

3、SavedModel等

保存与读取CheckPoint

当模型训练结束,可以用以下代码把权重保存成checkpoint格式

model.save_weights('./MyModel',True)

checkpoints文件仅是保存训练好的权重,不带网络结构,所以做predict时需要结合model使用

如:

model = keras_segmentation.models.segnet.mobilenet_segnet(n_classes=2, input_height=224, input_width=224)
model.load_weights('./MyModel')

保存成H5

把训练好的网络保存成h5文件很简单

model.save('MyModel.h5')

H5转换成TFLite

这里是文章主要内容

我习惯使用H5文件转换成tflite文件

官网代码是这样的

converter = tf.lite.TFLiteConverter.from_keras_model_file('newModel.h5')
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

但我用的keras 2.2.4版本会报下面错误,好像说是新版的keras把relu6改掉了,找不到方法

ValueError: Unknown activation function:relu6

于是需要自己定义一个relu6

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import CustomObjectScope

def relu6(x):
 return K.relu(x, max_value=6)

with CustomObjectScope({'relu6': relu6}):
  converter = tf.lite.TFLiteConverter.from_keras_model_file('newModel.h5')
  tflite_model = converter.convert()
  open("newModel.tflite", "wb").write(tflite_model)

看到生成的tflite文件表示保存成功了

也可以这么查看tflite网络的输入输出

import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="newModel.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details)
print(output_details)

输出了以下信息

[{'name': 'input_1', 'index': 115, 'shape': array([ 1, 224, 224, 3]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

[{'name': 'activation_1/truediv', 'index': 6, 'shape': array([ 1, 12544, 2]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

两个shape分别表示输入输出的numpy数组结构,dtype是数据类型

以上这篇keras .h5转移动端的.tflite文件实现方式)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 在keras下实现多个模型的融合方式

    在网上搜过发现关于keras下的模型融合框架其实很简单,奈何网上说了一大堆,这个东西官方文档上就有,自己写了个demo: # Function:基于keras框架下实现,多个独立任务分类 # Writer: PQF # Time: 2019/9/29 import numpy as np from keras.layers import Input, Dense from keras.models import Model import tensorflow as tf # 生成训练集 data

  • 关于keras.layers.Conv1D的kernel_size参数使用介绍

    今天在用keras添加卷积层的时候,发现了kernel_size这个参数不知怎么理解,keras中文文档是这样描述的: kernel_size: 一个整数,或者单个整数表示的元组或列表, 指明 1D 卷积窗口的长度. 又经过多方查找,大体理解如下: 因为是添加一维卷积层Conv1D(),一维卷积一般会处理时序数据,所以,卷积核的宽度为1,而kernel_size就是卷积核的长度了,这样的意思就是这个卷积核是一个长方形的卷积核. 补充知识:tf.layers.conv1d函数解析(一维卷积) 一维

  • 解决Keras 与 Tensorflow 版本之间的兼容性问题

    在利用Keras进行实验的时候,后端为Tensorflow,出现了以下问题: 1. 服务器端激活Anaconda环境跑程序时,实验结果很差. 环境:tensorflow 1.4.0,keras 2.1.5 2. 服务器端未激活Anaconda环境跑程序时,实验结果回到正常值. 环境:tensorflow 1.7.0,keras 2.0.8 3. 自己PC端跑相同程序时,实验结果回到正常值. 环境:tensorflow 1.6.0,keras 2.1.5 怀疑实验结果的异常性是由于Keras和Te

  • keras .h5转移动端的.tflite文件实现方式

    以前tensorflow有bug 在winodws下无法转,但现在好像没有问题了,代码如下 将keras 下的mobilenet_v2转成了tflite from keras.backend import clear_session import numpy as np import tensorflow as tf clear_session() np.set_printoptions(suppress=True) input_graph_name = "../models/weights.b

  • javascript html5移动端轻松实现文件上传

    PC端上传文件多半用插件,引入flash都没关系,但是移动端要是还用各种冗余的插件估计得被喷死,项目里面需要做图片上传的功能,既然H5已经有相关的接口且兼容性良好,当然优先考虑用H5来实现. 用的技术主要是: ajax FileReader FormData HTML结构: <div class="camera-area"> <form enctype="multipart/form-data" method="post">

  • Node.js上传文件功能之服务端如何获取文件上传进度

    内容概述 multer是常用的Express文件上传中间件.服务端如何获取文件上传的进度,是使用的过程中,很常见的一个问题.在SF上也有同学问了类似问题<nodejs multer有没有查看文件上传进度的方法?>.稍微回答了下,这里顺便整理出来,有同样疑问的同学可以参考. 下文主要介绍如何利用progress-stream获取文件上传进度,以及该组件使用过程中的注意事项. 利用progress-stream获取文件上传进度 如果只是想在服务端获取上传进度,可以试下如下代码.注意,这个模块跟Ex

  • Django实现web端tailf日志文件功能及实例详解

    这是Django Channels系列文章的第二篇,以web端实现tailf的案例讲解Channels的具体使用以及跟Celery的结合 通过上一篇 <Django使用Channels实现WebSocket--上篇> 的学习应该对Channels的各种概念有了清晰的认知,可以顺利的将Channels框架集成到自己的Django项目中实现WebSocket了,本篇文章将以一个Channels+Celery实现web端tailf功能的例子更加深入的介绍Channels 先说下我们要实现的目标:所有

  • Android和PC端通过局域网文件同步

    本文为大家分享了Android和PC端通过局域网文件同步的具体代码,供大家参考,具体内容如下 public class FileOptions { public String name; public String path; public long size; } //Activity public class MainActivity extends Activity { private TextView tvMsg; private EditText logShow, filePath;

  • keras小技巧——获取某一个网络层的输出方式

    前言: keras默认提供了如何获取某一个层的某一个节点的输出,但是没有提供如何获取某一个层的输出的接口,所以有时候我们需要获取某一个层的输出,则需要自己编写代码,但是鉴于keras高层封装的特性,编写起来实际上很简单,本文提供两种常见的方法来实现,基于上一篇文章的模型和代码: keras自定义回调函数查看训练的loss和accuracy 一.模型加载以及各个层的信息查看 从前面的定义可知,参见上一篇文章,一共定义了8个网络层,定义如下: model.add(Convolution2D(filt

  • C/C++读取大文件数据方式详细讲解

    目录 前言 第一种方法 第二种方法 第三种方法 解决 前言 以前对C语言与C++不够了解时,我无法知道如何完整获取一个文件的所有数据并且不遗漏掉. 在网络上也搜索了很多很多的相关帖子,但是没有一个是真正有用的. 本文章使用C语言进行演示,如需使用C++的话原理为一样的. 以下列出那些没用的代码 第一种方法 // 创建一个变量,然后使用FILE指针打开一个文件 // 用fgetc函数与循环代码不断将数据读取到变量中 uint8_t data[4096]; FILE *fp = fopen("文件路

  • 浅谈js文件引用方式及其同步执行与异步执行

    任何以appendChild(scriptNode) 的方式引入的js文件都是异步执行的 (scriptNode 需要插入document中,只创建节点和设置 src 是不会加载 js 文件的,这跟 img 的与加载不同 ) html文件中的<script>标签中的代码或src引用的js文件中的代码是同步加载和执行的 html文件中的<script>标签中的代码使用document.write()方式引入的js文件是异步执行的 html文件中的<script>标签src

  • 详解Python中open()函数指定文件打开方式的用法

    文件打开方式 当我们用open()函数去打开文件的时候,有好几种打开的模式. 'r'->只读 'w'->只写,文件已存在则清空,不存在则创建. 'a'->追加,写到文件末尾 'b'->二进制模式,比如打开图像.音频.word文件. '+'->更新(可读可写) 这个带'+'号的有点难以理解,上代码感受下. with open('foo.txt', 'w+') as f: f.write('bar\n') f.seek(0) data = f.read() 可以看到,上面这段代码

  • GO语言常用的文件读取方式

    本文实例讲述了GO语言常用的文件读取方式.分享给大家供大家参考.具体分析如下: Golang 的文件读取方法很多,刚上手时不知道怎么选择,所以贴在此处便后速查. 一次性读取 小文件推荐一次性读取,这样程序更简单,而且速度最快. 复制代码 代码如下: func ReadAll(filePth string) ([]byte, error) {  f, err := os.Open(filePth)  if err != nil {   return nil, err  } return iouti

随机推荐