一文详解CNN 解决 Flowers 图像分类任务

目录
  • 前言
  • 加载并展示数据
  • 构件处理图像的 pipeline
  • 搭建深度学习分类模型
  • 训练模型并观察结果
  • 加入了抑制过拟合措施并重新进行模型的训练和测试

前言

本文主要任务是使用通过 tf.keras.Sequential 搭建的模型进行各种花朵图像的分类,主要涉及到的内容有三个部分:

  • 使用 tf.keras.Sequential 搭建模型。
  • 使用 tf.keras.utils.image_dataset_from_directory 从磁盘中高效加载数据。
  • 使用了一定的防止过拟合的方法,如丰富训练样本的数量、在数据处理过程中加入了数据增强、全连接层加入了 Dropout 等。

本文所用的环境为 tensorlfow-cpu= 2.4 ,python 版本为 3.8 。

主要章节介绍如下:

  • 加载并展示数据
  • 构件处理图像的 pipeline
  • 搭建深度学习分类模型
  • 训练模型并观察结果
  • 加入了抑制过拟合措施并重新进行模型的训练和测试

加载并展示数据

(1)该数据需要从网上下载,需要耐心等待片刻,下载下来自动会存放在“你的主目录.keras\datasets\flower_photos”。

(2)数据中总共有 5 种类,分别是 daisy、 dandelion、roses、sunflowers、tulips,总共包含了 3670 张图片。

(3) 随机展示了一张花朵的图片。

import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import random
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.jpg')))
print("总共包含%d张图片,下面随便展示一张玫瑰的图片样例:"%image_count)
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(random.choice(roses)))

结果打印:

总共包含3670张图片,下面随便展示一张玫瑰的图片样例:

构件处理图像的 pipeline

(1)使用 tf.keras.utils.image_dataset_from_directory 可以将我们的花朵图片数据,从磁盘加载到内存中,并形成 tensorflow 高效的 tf.data.Dataset 类型。

(2)我们将数据集 shuffle 之后,进行二八比例的随机抽取分配,80% 的数据作为我们的训练集,共 2936 张图片, 20% 的数据集作为我们的测试集,共 734 张图片。

(3)我们使用 Dataset.cache 和 Dataset.prefetch 来提升数据的处理速度,使用 cache 在将数据从磁盘加载到 cache 之后,就可以将数据一直放 cache 中便于我们的后续访问,这可以保证在训练过程中数据的处理不会成为计算的瓶颈。另外使用 prefetch 可以在 GPU 训练模型的时候,CPU 将之后需要的数据提前进行处理放入 cache 中,也是为了提高数据的处理性能,加快整个训练过程,不至于训练模型时浪费时间等待数据。

(4)我们随便选取了 6 张图像进行展示,可以看到它们的图片以及对应的标签。

batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.utils.image_dataset_from_directory( data_dir, validation_split=0.2, subset="training", seed=1, image_size=(img_height, img_width), batch_size=batch_size)
val_ds = tf.keras.utils.image_dataset_from_directory( data_dir,  validation_split=0.2, subset="validation", seed=1, image_size=(img_height, img_width),batch_size=batch_size)
class_names = train_ds.class_names
num_classes = len(class_names)
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
plt.figure(figsize=(5, 5))
for images, labels in train_ds.take(1):
    for i in range(6):
        ax = plt.subplot(2, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

结果打印:

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.

搭建深度学习分类模型

(1)因为最初的图片都是 RGB 三通道图片,像素点的值在 [0,255] 之间,为了加速模型的收敛,我们要将所有的数据进行归一化操作。所以在模型的第一层加入了 layers.Rescaling 对图片进行处理。

(2)使用了三个卷积块,每个卷积块中包含了卷积层和池化层,并且每一个卷积层中都添加了 relu 激活函数,卷积层不断提取图片的特征,池化层可以有效的所见特征矩阵的尺寸,同时也可以减少最后连接层的中的参数数量,权重参数少的同时也起到了加快计算速度和防止过拟合的作用。

(3)最后加入了两层全连接层,输出对图片的分类预测 logit 。

(4)使用 Adam 作为我们的模型优化器,使用 SparseCategoricalCrossentropy 计算我们的损失值,在训练过程中观察 accuracy 指标。

model = Sequential([
  layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

训练模型并观察结果

(1)我们使用训练集进行模型的训练,使用验证集进行模型的验证,总共训练 5 个 epoch 。

(2)我们通过对训练过程中产生的准确率和损失值,与验证过程中产生的准确率和损失值进行绘图对比,训练时的准确率高出验证时的准确率很多,训练时的损失值远远低于验证时的损失值,这说明模型存在过拟合风险。正常的情况这两个指标应该是大体呈现同一个发展趋势。

epochs = 5
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

结果打印:

Epoch 1/5
92/92 [==============================] - 45s 494ms/step - loss: 0.2932 - accuracy: 0.8992 - val_loss: 1.2603 - val_accuracy: 0.6417
Epoch 2/5
92/92 [==============================] - 40s 436ms/step - loss: 0.1814 - accuracy: 0.9414 - val_loss: 1.5241 - val_accuracy: 0.6267
Epoch 3/5
92/92 [==============================] - 36s 394ms/step - loss: 0.0949 - accuracy: 0.9745 - val_loss: 1.6629 - val_accuracy: 0.6499
Epoch 4/5
92/92 [==============================] - 48s 518ms/step - loss: 0.0554 - accuracy: 0.9860 - val_loss: 1.7566 - val_accuracy: 0.6621
Epoch 5/5
92/92 [==============================] - 39s 419ms/step - loss: 0.0341 - accuracy: 0.9918 - val_loss: 2.1150 - val_accuracy: 0.6335

加入了抑制过拟合措施并重新进行模型的训练和测试

(1)当训练样本数量较少时,通常会发生过拟合现象。我们可以操作数据增强技术,通过随机翻转、旋转等方式来增加样本的丰富程度。常见的数据增强处理方式有:tf.keras.layers.RandomFlip、tf.keras.layers.RandomRotation和 tf.keras.layers.RandomZoom。这些方法可以像其他层一样包含在模型中,并在 GPU 上运行。

(2)这里挑选了一张图片,对其进行 6 次执行数据增强,可以看到得到了经过一定程度缩放、旋转、反转的数据集。

data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.5)
])
plt.figure(figsize=(5, 5))
for images, _ in train_ds.take(1):
    for i in range(6):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(2, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")

(3)在模型架构的开始加入数据增强层,同时在全连接层的地方加入 Dropout ,进行神经元的随机失活,这两个方法的加入可以有效抑制模型过拟合的风险。其他的模型结构、优化器、损失函数、观测值和之前相同。通过绘制数据图我们发现,使用这些措施很明显减少了过拟合的风险。

model = Sequential([
  data_augmentation,
  layers.Rescaling(1./255),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Dropout(0.2),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes, name="outputs")
])
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
epochs = 15
history = model.fit( train_ds, validation_data=val_ds, epochs=epochs)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

结果打印:

92/92 [==============================] - 57s 584ms/step - loss: 1.3080 - accuracy: 0.4373 - val_loss: 1.0929 - val_accuracy: 0.5749
Epoch 2/15
92/92 [==============================] - 41s 445ms/step - loss: 1.0763 - accuracy: 0.5596 - val_loss: 1.3068 - val_accuracy: 0.5204
...
Epoch 14/15
92/92 [==============================] - 59s 643ms/step - loss: 0.6306 - accuracy: 0.7585 - val_loss: 0.7963 - val_accuracy: 0.7044
Epoch 15/15
92/92 [==============================] - 42s 452ms/step - loss: 0.6155 - accuracy: 0.7691 - val_loss: 0.8513 - val_accuracy: 0.6975

(4)最后我们使用一张随机下载的图片,用模型进行类别的预测,发现可以识别出来。

sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)
img = tf.keras.utils.load_img(  sunflower_path, target_size=(img_height, img_width) )
img_array = tf.keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)
predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])
print(  "这张图片最有可能属于 {} ,有 {:.2f} 的置信度。".format(class_names[np.argmax(score)], 100 * np.max(score)))

结果打印:

这张图片最有可能属于 sunflowers ,有 97.39 的置信度。

以上就是一文详解CNN 解决 Flowers 图像分类任务的详细内容,更多关于CNN Flowers图像分类的资料请关注我们其它相关文章!

(0)

相关推荐

  • python神经网络Keras构建CNN网络训练

    目录 Keras中构建CNN的重要函数 1.Conv2D 2.MaxPooling2D 3.Flatten 全部代码 利用Keras构建完普通BP神经网络后,还要会构建CNN Keras中构建CNN的重要函数 1.Conv2D Conv2D用于在CNN中构建卷积层,在使用它之前需要在库函数处import它. from keras.layers import Conv2D 在实际使用时,需要用到几个参数. Conv2D( nb_filter = 32, nb_row = 5, nb_col = 5

  • Keras目标检测mtcnn facenet搭建人脸识别平台

    目录 什么是mtcnn和facenet 1.mtcnn 2.facenet 实现流程 一.数据库的初始化 二.实时图片的处理 1.人脸的截取与对齐 2.利用facenet对矫正后的人脸进行编码 3.将实时图片中的人脸特征与数据库中的进行比对 4.实时处理图片整体代码 全部代码: 什么是mtcnn和facenet 1.mtcnn MTCNN,英文全称是Multi-task convolutional neural network,中文全称是多任务卷积神经网络,该神经网络将人脸区域检测与人脸关键点检

  • python人工智能tensorflow构建卷积神经网络CNN

    目录 简介 隐含层介绍 1.卷积层 2.池化层 3.全连接层 具体实现代码 卷积层.池化层与全连接层实现代码 全部代码 学习神经网络已经有一段时间,从普通的BP神经网络到LSTM长短期记忆网络都有一定的了解,但是从未系统的把整个神经网络的结构记录下来,我相信这些小记录可以帮助我更加深刻的理解神经网络. 简介 卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),

  • 人工智能学习PyTorch实现CNN卷积层及nn.Module类示例分析

    目录 1.CNN卷积层 2. 池化层 3.数据批量标准化 4.nn.Module类 ①各类函数 ②容器功能 ③参数管理 ④调用GPU ⑤存储和加载 ⑥训练.测试状态切换 ⑦ 创建自己的层 5.数据增强 1.CNN卷积层 通过nn.Conv2d可以设置卷积层,当然也有1d和3d. 卷积层设置完毕,将设置好的输入数据,传给layer(),即可完成一次前向运算.也可以传给layer.forward,但不推荐. 2. 池化层 池化层的核大小一般是2*2,有2种方式: maxpooling:选择数据中最大

  • Keras搭建Mask R-CNN实例分割平台实现源码

    目录 什么是Mask R-CNN Mask R-CNN实现思路 一.预测部分 1.主干网络介绍 2.特征金字塔FPN的构建 3.获得Proposal建议框 4.Proposal建议框的解码 5.对Proposal建议框加以利用(Roi Align) 6.预测框的解码 7.mask语义分割信息的获取 二.训练部分 1.建议框网络的训练 2.Classiffier模型的训练 3.mask模型的训练 训练自己的Mask-RCNN模型 1.数据集准备 2.参数修改 3.模型训练 什么是Mask R-CN

  • Python人工智能深度学习CNN

    目录 1.CNN概述 2.卷积层 3.池化层 4.全连层 1.CNN概述 CNN的整体思想,就是对图片进行下采样,让一个函数只学一个图的一部分,这样便得到少但是更有效的特征,最后通过全连接神经网络对结果进行输出. 整体架构如下: 输入图片 →卷积:得到特征图(激活图) →ReLU:去除负值 →池化:缩小数据量同时保留最有效特征 (以上步骤可多次进行) →输入全连接神经网络 2.卷积层 CNN-Convolution 卷积核(或者被称为kernel, filter, neuron)是要被学出来的,

  • 一文详解CNN 解决 Flowers 图像分类任务

    目录 前言 加载并展示数据 构件处理图像的 pipeline 搭建深度学习分类模型 训练模型并观察结果 加入了抑制过拟合措施并重新进行模型的训练和测试 前言 本文主要任务是使用通过 tf.keras.Sequential 搭建的模型进行各种花朵图像的分类,主要涉及到的内容有三个部分: 使用 tf.keras.Sequential 搭建模型. 使用 tf.keras.utils.image_dataset_from_directory 从磁盘中高效加载数据. 使用了一定的防止过拟合的方法,如丰富训

  • Python中str is not callable问题详解及解决办法

    Python中str is not callable问题详解及解决办法 问题提出: 在Python的代码,在运行过程中,碰到了一个错误信息: python代码: def check_province_code(province, country): num = len(province) while num <3: province = ''.join([str(0),province]) num = num +1 return country + province 运行的错误信息: check

  • 详解Java中NullPointerException异常的原因详解以及解决方法

    NullPointerException是当您尝试使用指向内存中空位置的引用(null)时发生的异常,就好像它引用了一个对象一样. 当我们声明引用变量(即对象)时,实际上是在创建指向对象的指针.考虑以下代码,您可以在其中声明基本类型的整型变量x: int x; x = 10; 在此示例中,变量x是一个整型变量,Java将为您初始化为0.当您在第二行中将其分配给10时,值10将被写入x指向的内存中. 但是,当您尝试声明引用类型时会发生不同的事情.请使用以下代码: Integer num; num

  • 一文详解JS私有属性的6种实现方式

    目录 _prop Proxy Symbol WeakMap #prop ts private 总结 class 是创建对象的模版,由一系列属性和方法构成,用于表示对同一概念的数据和操作. 有的属性和方法是对外的,但也有的是只想内部用的,也就是私有的,那怎么实现私有属性和方法呢? 不知道大家会怎么实现,我梳理了下,我大概用过 6 种方式,我们分别来看一下: _prop 区分私有和公有最简单的方式就是加个下划线 _,从命名上来区分. 比如: class Dong { constructor() {

  • 一文详解Java中的类加载机制

    目录 一.前言 二.类加载的时机 2.1 类加载过程 2.2 什么时候类初始化 2.3 被动引用不会初始化 三.类加载的过程 3.1 加载 3.2 验证 3.3 准备 3.4 解析 3.5 初始化 四.父类和子类初始化过程中的执行顺序 五.类加载器 5.1 类与类加载器 5.2 双亲委派模型 5.3 破坏双亲委派模型 六.Java模块化系统 一.前言 Java虚拟机把描述类的数据从Class文件加载到内存,并对数据进行校验.转换解析和初始化,最 终形成可以被虚拟机直接使用的Java类型,这个过程

  • 一文详解Java线程的6种状态与生命周期

    目录 1.线程状态(生命周期) 2.操作线程状态 2.1.新创建状态(NEW) 2.2.可运行状态(RUNNABLE) 2.3.被阻塞状态(BLOCKED) 2.4.等待唤醒状态(WAITING) 2.5.计时等待状态(TIMED_WAITING) 2.6.终止(TERMINATED) 3.查看线程的6种状态 1.线程状态(生命周期) 一个线程在给定的时间点只能处于一种状态. 线程可以有如下6 种状态: New (新创建):未启动的线程: Runnable (可运行):可运行的线程,需要等待操作

  • 一文详解Java线程中的安全策略

    目录 一.不可变对象 二.线程封闭 三.线程不安全类与写法 四.线程安全-同步容器 1. ArrayList -> Vector, Stack 2. HashMap -> HashTable(Key, Value都不能为null) 3. Collections.synchronizedXXX(List.Set.Map) 五.线程安全-并发容器J.U.C 1. ArrayList -> CopyOnWriteArrayList 2.HashSet.TreeSet -> CopyOnW

  • 一文详解Java中Stream流的使用

    目录 简介 操作1:创建流 操作2:中间操作 筛选(过滤).去重 映射 排序 消费 操作3:终止操作 匹配.最值.个数 收集 规约 简介 说明 本文用实例介绍stream的使用. JDK8新增了Stream(流操作) 处理集合的数据,可执行查找.过滤和映射数据等操作. 使用Stream API 对集合数据进行操作,就类似于使用 SQL 执行的数据库查询.可以使用 Stream API 来并行执行操作. 简而言之,Stream API 提供了一种高效且易于使用的处理数据的方式. 特点 不是数据结构

  • 一文详解Vue3响应式原理

    目录 回顾 vue2.x 的响应式 vue3的响应式 Reflect 回顾 vue2.x 的响应式 实现原理: 对象类型:通过object.defineProperty()对属性的读取.修改进行拦截(数据劫持) 数组类型:通过重写更新数组的一系列方法来实现拦截(对数组的变更方法进行了包裹) Object.defineProperty(data,'count ",{ get(){}, set(){} }) 存在问题: 新增属性.删除属性,界面不会更新 直接通过下标修改数组,界面不会自动更新 但是

  • 一文详解Spring如何控制Bean注入的顺序

    目录 简介 构造方法依赖(推荐) @DependsOn(不推荐) BeanPostProcessor(不推荐) 简介 说明 本文介绍Spring如何控制Bean注入的顺序. 首先需要说明的是:在Bean上加@Order(xxx)是无法控制bean注入的顺序的! 控制bean的加载顺序的方法 1.构造方法依赖 2.@DependsOn 注解 3.BeanPostProcessor 扩展 Bean初始化顺序与类加载顺序基本一致:静态变量/语句块=> 实例变量或初始化语句块=> 构造方法=>

随机推荐