Tensorflow2.1 MNIST图像分类实现思路分析

目录
  • 前言
  • 主要思路和实现
    • (1) 加载数据,处理数据
    • (2) 使用 keras 搭建深度学习模型
    • (3) 定义损失函数
    • (4) 配置编译模型
    • (5) 使用训练数据训练模型
    • (6) 使用测试数据评估模型
    • (7) 展示不使用归一化的操作的训练和评估结果

前言

之前工作中主要使用的是 Tensorflow 1.15 版本,但是渐渐跟不上工作中的项目需求了,而且因为 2.x 版本和 1.x 版本差异较大,所以要专门花时间学习一下 2.x 版本,本文作为学习 Tensorflow 2.x 版本的开篇,主要介绍了使用 cpu 版本的 Tensorflow 2.1 搭建深度学习模型,完成对于 MNIST 数据的图片分类的任务。

主要思路和实现

(1) 加载数据,处理数据

这里是要导入 tensorflow 的包,前提是你要提前安装 tensorflow ,我这里为了方便直接使用的是 cpu 版本的 tensorflow==2.1.0 ,如果是为了学习的话,cpu 版本的也够用了,毕竟数据量和模型都不大。

import tensorflow as tf

这里是为了加载 mnist 数据集,mnist 数据集里面就是 0-9 这 10 个数字的图片集,我们要使用深度学习实现一个模型完成对 mnist 数据集进行分类的任务,这个项目相当于 java 中 hello world 。

mnist = tf.keras.datasets.mnist

这里的 (x_train, y_train) 表示的是训练集的图片和标签,(x_test, y_test) 表示的是测试集的图片和标签。

(x_train, y_train), (x_test, y_test) = mnist.load_data()

每张图片是 28*28 个像素点(数字)组成的,而每个像素点(数字)都是 0-255 中的某个数字,我们对其都除 255 ,这样就是相当于对这些图片的像素点值做归一化,这样有利于模型加速收敛,在本项目中执行本操作比不执行本操作最后的准确率高很多,在文末会展示注释本行情况下,模型评估的指标结果,大家可以自行对比差异。

x_train, x_test = x_train / 255.0, x_test / 255.0

(2) 使用 keras 搭建深度学习模型

这里主要是要构建机器学习模型,模型分为以下几层:

  • 第一层要接收图片的输入,每张图片是 28*28 个像素点组成的,所以 input_shape=(28, 28)
  • 第二层是一个输出 128 维度的全连接操作
  • 第三层是要对第二层的输出随机丢弃 20% 的 Dropout 操作,这样有利于模型的泛化

第四层是一个输出 10 维度的全连接操作,也就是预测该图片分别属于这十种类型的概率

 model = tf.keras.models.Sequential([
   tf.keras.layers.Flatten(input_shape=(28, 28)),
   tf.keras.layers.Dense(128, activation='relu'),
   tf.keras.layers.Dropout(0.2),
   tf.keras.layers.Dense(10)
 ])

(3) 定义损失函数

这里主要是定义损失函数,这里的损失函数使用到了 SparseCategoricalCrossentropy ,主要是为了计算标签和预测结果之间的交叉熵损失。

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

(4) 配置编译模型

这里主要是配置和编译模型,优化器使用了 adam ,要优化的评价指标选用了准确率 accuracy ,当然了还可以选择其他的优化器和评价指标。

model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

(5) 使用训练数据训练模型

这里主要使用训练数据的图片和标签来训练模型,将整个训练样本集训练 5 次。

model.fit(x_train, y_train, epochs=5)

训练过程结果输出如下:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 3s 43us/sample - loss: 0.2949 - accuracy: 0.9144
Epoch 2/5
60000/60000 [==============================] - 2s 40us/sample - loss: 0.1434 - accuracy: 0.9574
Epoch 3/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.1060 - accuracy: 0.9676
Epoch 4/5
60000/60000 [==============================] - 2s 31us/sample - loss: 0.0891 - accuracy: 0.9721
Epoch 5/5
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0740 - accuracy: 0.9771
10000/10000 - 0s - loss: 0.0744 - accuracy: 0.9777

(6) 使用测试数据评估模型

这里主要是使用测试数据中的图片和标签来评估模型,verbose 可以选为 0、1、2 ,区别主要是结果输出的形式不一样,嫌麻烦可以不设置

model.evaluate(x_test,  y_test, verbose=2)

评估的损失值和准确率如下:

[0.07444974237508141, 0.9777]

(7) 展示不使用归一化的操作的训练和评估结果

在不使用归一化操作的情况下,训练过程输出如下:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 3s 42us/sample - loss: 2.4383 - accuracy: 0.7449
Epoch 2/5
60000/60000 [==============================] - 2s 40us/sample - loss: 0.5852 - accuracy: 0.8432
Epoch 3/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.4770 - accuracy: 0.8724
Epoch 4/5
60000/60000 [==============================] - 2s 34us/sample - loss: 0.4069 - accuracy: 0.8950
Epoch 5/5
60000/60000 [==============================] - 2s 32us/sample - loss: 0.3897 - accuracy: 0.8996
10000/10000 - 0s - loss: 0.2898 - accuracy: 0.9285

评估结果输入如下:

[0.2897613683119416, 0.9285]

所以我们通过和上面的进行对比发现,不进行归一化操作,在训练过程中收敛较慢,在相同 epoch 的训练之后,评估的准确率和损失值都不理想,损失值比第(6)步操作的损失值大,准确率比第(6)步操作低 5% 左右。

以上就是Tensorflow2.1 MNIST图像分类实现思路分析的详细内容,更多关于Tensorflow2.1 MNIST图像分类的资料请关注我们其它相关文章!

(0)

相关推荐

  • TensorFlow神经网络创建多层感知机MNIST数据集

    前面使用TensorFlow实现一个完整的Softmax Regression,并在MNIST数据及上取得了约92%的正确率. 前文传送门: TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集 现在建含一个隐层的神经网络模型(多层感知机). import tensorflow as tf import numpy as np import input_data mnist = input_data.read_data_sets('data/', one_hot=True)

  • Tensorflow 2.1完成对MPG回归预测详解

    目录 前言 1. 获取 Auto MPG 数据并进行数据的归一化处理 2. 对数据进行处理 搭建深度学习模型 使用 EarlyStoping 完成模型训练 使用测试数据对模型进行评估 使用模型进行预测 展示没有进行归一化操作的训练过程 前言 本文的主要内容是使用 cpu 版本的 tensorflor-2.1 完成对 Auto MPG 数据集的回归预测任务. 本文大纲 获取 Auto MPG 数据 对数据进行处理 搭建深度学习模型.并完成模型的配置和编译 使用 EarlyStoping 完成模型训

  • TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集

    基于MNIST数据集的逻辑回归模型做十分类任务 没有隐含层的Softmax Regression只能直接从图像的像素点推断是哪个数字,而没有特征抽象的过程.多层神经网络依靠隐含层,则可以组合出高阶特征,比如横线.竖线.圆圈等,之后可以将这些高阶特征或者说组件再组合成数字,就能实现精准的匹配和分类. import tensorflow as tf import numpy as np import input_data print('Download and Extract MNIST datas

  • TensorFlow卷积神经网络MNIST数据集实现示例

    这里使用TensorFlow实现一个简单的卷积神经网络,使用的是MNIST数据集.网络结构为:数据输入层–卷积层1–池化层1–卷积层2–池化层2–全连接层1–全连接层2(输出层),这是一个简单但非常有代表性的卷积神经网络. import tensorflow as tf import numpy as np import input_data mnist = input_data.read_data_sets('data/', one_hot=True) print("MNIST ready&q

  • Tensorflow2.1实现文本中情感分类实现解析

    目录 前言 实现过程和思路解析 下载影评数据并进行 padding 处理 创建验证集数据 搭建简单的深度学习模型 配置并编译模型 训练模型 评估模型 前言 本文主要是用 cpu 版本的 tensorflow 2.1 搭建深度学习模型,完成对电影评论的情感分类任务. 本次实践的数据来源于IMDB 数据集,里面的包含的是电影的影评,每条影评评论文本分为积极类型或消极类型.数据集总共包含 50000 条影评文本,取该数据集的 25000 条影评数据作为训练集,另外 25000 条作为测试集,训练集与测

  • Tensorflow 2.4加载处理图片的三种方式详解

    目录 前言 数据准备 使用内置函数读取并处理磁盘数据 自定义方式读取和处理磁盘数据 从网络上下载数据 前言 本文通过使用 cpu 版本的 tensorflow 2.4 ,介绍三种方式进行加载和预处理图片数据. 这里我们要确保 tensorflow 在 2.4 版本以上 ,python 在 3.8 版本以上,因为版本太低有些内置函数无法使用,然后要提前安装好 pillow 和 tensorflow_datasets ,方便进行后续的数据加载和处理工作. 由于本文不对模型进行质量保证,只介绍数据的加

  • Tensorflow2.1 MNIST图像分类实现思路分析

    目录 前言 主要思路和实现 (1) 加载数据,处理数据 (2) 使用 keras 搭建深度学习模型 (3) 定义损失函数 (4) 配置编译模型 (5) 使用训练数据训练模型 (6) 使用测试数据评估模型 (7) 展示不使用归一化的操作的训练和评估结果 前言 之前工作中主要使用的是 Tensorflow 1.15 版本,但是渐渐跟不上工作中的项目需求了,而且因为 2.x 版本和 1.x 版本差异较大,所以要专门花时间学习一下 2.x 版本,本文作为学习 Tensorflow 2.x 版本的开篇,主

  • jQuery.prototype.init选择器构造函数源码思路分析

    一.源码思路分析总结 概要: jQuery的核心思想可以简单概括为"查询和操作dom",今天主要是分析一下jQuery.prototype.init选择器构造函数,处理选择器函数中的参数: 这个函数的参数就是jQuery()===$()执行函数中的参数,可以先看我之前写的浅析jQuery基础框架一文,了解基础框架后,再看此文. 思路分析: 以下是几种jQuery的使用情况(用于查询dom),每种情况都返回一个选择器实例(习惯称jQuery对象(一个nodeList对象),该对象包含查询

  • jQuery.clean使用方法及思路分析

    一.jQuery.clean使用方法jQuery.clean( elems, context, fragment, scripts );二.思路分析1.处理参数context,确保其为文档根节点document2.处理参数elems数组(循环遍历数组) 2.1.elem为数字,转换为字符串 2.2.elem为非法值,跳出本次循环 2.3.elem为字符串 2.4.字符串不存在实体编号或html标签,则创建文本节点 2.5.字符串为实体编号或html标签 复制代码 代码如下: 创建一个div元素并

  • jQuery.buildFragment使用方法及思路分析

    一.jQuery.buildFragment使用方法 1.参数 jQuery.buildFragment( args, context, scripts );2.返回值 return { fragment: fragment, cacheable: cacheable }; 二.思路分析 1.处理context参数 根据传入到context参数值的不同,确保context为文档根节点document 2.限制可缓存条件 2.1.字符串小于512字节 2.2.字符串不存在option标签(克隆op

  • Linux NFS服务器安装与配置思路分析

    一,nfs服务优缺点 NFS服务简介 NFS 是Network File System的缩写,即网络文件系统.一种使用于分散式文件系统的协定,由Sun公司开发,于1984年向外公布.功能是通过网络让不同的机器.不同的操作系统能够彼此分享个别的数据,让应用程序在客户端通过网络访问位于服务器磁盘中的数据,是在类Unix系统间实现磁盘文件共享的一种方法. NFS 的基本原则是"容许不同的客户端及服务端通过一组RPC分享相同的文件系统",它是独立于操作系统,容许不同硬件及操作系统的系统共同进行

  • vue-cli3+typescript新建一个项目的思路分析

    最近在用vue搭一个后台管理的单页应用的demo,因为之前只用过vue-cli2+javascript进行开发,而vue-cli3早在去年8月就已经发布,并且对于typescript有了很好地支持.所以为了熟悉新技术,我选择使用vue-cli3+typescript进行新应用的开发.这里是新技术的学习记录. 初始化项目 卸载老版本脚手架,安装新版本脚手架后,开始初始化项目.初始化的命令跟2.x版本的略有不同,以前是 vue init webpack project-name ,而现在是 vue

  • vue2.0的虚拟DOM渲染思路分析

    1.为什么需要虚拟DOM 前面我们从零开始写了一个简单的类Vue框架(文章链接),其中的模板解析和渲染是通过Compile函数来完成的,采用了文档碎片代替了直接对页面中DOM元素的操作,在完成数据的更改后通过appendChild函数将真实的DOM插入到页面. 虽然采用的是文档碎片,但是操作的还是真实的DOM. 而我们知道操作DOM的代价是昂贵的,所以vue2.0采用了虚拟DOM来代替对真实DOM的操作,最后通过某种机制来完成对真实DOM的更新,渲染视图. 所谓的虚拟DOM,其实就是 用JS来模

  • vue项目两种方式实现竖向表格的思路分析

    问题描述 在我们做项目中,常见的是横向表格,但是偶尔的需求,也会做竖向的表格.比如下图这样的竖向表格: 我们看到这样的效果图,第一时间想到的是使用UI框架,改一改搞定.但是饿了么UI并没有直接提供这样的案例,部分同学会想着使用饿了么UI中的el-table的合并行.合并列的方式去实现,其实如果这样去做的话,反而做麻烦了.比如下面的合并行合并列: 类似于这样的效果图,其实并不一定非得使用UI组件,有的时候使用原生的方式去做.反而会更方便.本文介绍两种方式去实现这样的简单的竖向表格.实际场景中可能会

  • Springboot死信队列 DLX 配置和使用思路分析

    目录 前言 什么是死信 配置和测试死信 思路分析 配置类编写 编写消息发送服务 测试 消息什么时候会成为死信消息? 总结 参考资料 代码下载 前言 上一篇博客Springboot——整合RabbitMq测试TTL中,针对设置单个消息期限或者整个队列消息期限,进行了一些配置和说明.同时也都列举了一些区别关系. 但考虑过一个问题了没有? 不管是设置哪种方式,如果消息期限到了,队列都会将该消息进行丢弃处理.这么做合适么? 假设是某个设备的重要信息,或者某个重要的订单信息,因为规定时间内未被及时消费就将

  • Vue3中的执行流程思路分析-流程图

    目录 一. 前言 二. Vue3 思路分析 1. createRender(options) 2. createApp 3. app.mount(‘#app’) 4. render(vnode, container) 5. patch(n1, n2, container) 6. processComponent 7. mountComponent 8. setupRenderEffect 9. patch 10. processElement mountElement 三. 结尾 一. 前言 本

随机推荐