Tensorflow2.10实现图像分割任务示例详解

目录
  • 前言
  • 准备
  • 大纲
  • 实现
    • 1. 获取数据
    • 2. 处理数据
    • 3. 搭建模型
    • 4. 编译、训练模型
    • 5. 预测

前言

图像分割在医学成像、自动驾驶汽车和卫星成像等方面有很多应用,本质其实就是图像像素分类任务,也就是使用深度学习模型为输入图像的每个像素分配一个标签(或类)。

准备

本文的准备如下,使用 pip 安装如下配置:

  • pip install git+github.com/tensorflow/…
  • pip install tensorflow == 2.10.1
  • pip install tensorflow_datasets == 4.7.8
  • pip install ipython == 8.6.0
  • pip install matplotlib == 3.6.2

大纲

  • 获取数据
  • 处理数据
  • 搭建模型
  • 编译、训练模型
  • 预测

实现

1. 获取数据

(1)本文使用的数据集是 Oxford-IIIT Pet Dataset ,该数据集由 37 类宠物的图像组成,每个品种有 200 个图像(训练集和测试集各有 100 个),每个像素都会被划入以下三个类别之一:

  • 属于宠物的像素
  • 宠物边缘的像素
  • 其他位置的像素

(2)可以使用 TensorFlow 的内置函数从网络上下载本次使用的数据 oxford_iiit_pet ,一般会下载到本地目录 :C:\Users\【用户目录】\tensorflow_datasets\oxford_iiit_pet 。

(3)dataset 中存放是训练集和测试集这两个数据集,info 中存放的是该数据的基本信息,如文件大小,数据介绍等基本信息。

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
from IPython.display import clear_output
import matplotlib.pyplot as plt
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

2. 处理数据

(1)normalize 函数主要是完成将图像颜色值被归一化到 [0,1] 范围,掩码像素的所属标签被标记为 {1, 2, 3}。为了方便后面的模型计算,将它们分别减去 1,得到的标签为:{0, 1, 2} 。

(2)load_image 函数主要是将每个图片的输入和掩码图片,使用指定的方法将其大小调整为指定的 128x128 。

(3)从 dataset 中分理处训练集 train_images 和测试集 test_images 。

def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return input_image, input_mask
def load_image(image):
    input_image = tf.image.resize(image['image'], (128, 128))
    input_mask = tf.image.resize(image['segmentation_mask'], (128, 128))
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

(4)为了保证在加载数据的时候不会出现 I/O 不会阻塞,我们在从磁盘加载完数据之后,使用 cache 会将数据保存在内存中,确保在训练模型过程中数据的获取不会成为训练速度的瓶颈。

如果说要保存的数据量太大,可以使用 cache 创建磁盘缓存提高数据的读取效率。另外我们还使用 prefetch 在训练过程中可以并行执行数据的预获取。

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 32
BUFFER_SIZE = 1000
train_batches = (train_images.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat().prefetch(buffer_size=tf.data.AUTOTUNE))
test_batches = test_images.batch(BATCH_SIZE)

(5)这里的 display 函数主要是将每个样本的宠物图像、对应的掩码图像、预测的掩码图像绘制出来,在这里我们只随机挑选了一个样本进行显示。因为这里还没有预测的掩码图像,所以没有将其绘制出来。

(6)我们可以看到左侧是一张宠物的生活照,右边是一张该宠物在照片中的轮廓线图,宠物的样子所处的像素为紫色,宠物的轮廓边缘线的像素是黄色,背景的像素是墨绿色,这其实对应了图片中的像素会分成三个类别。

def display(display_list):
    plt.figure(figsize=(15, 15))
    title = ['Input Image', 'True Mask', 'Predicted Mask']
    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()
for images, masks in train_batches.take(1):
    sample_image, sample_mask = images[0], masks[0]
    display([sample_image, sample_mask])

3. 搭建模型

(1)这里使用的模型是修改后的 U-Net ,详细内容可看链接。U-Net 由编码器(下采样器)和解码器(上采样器)组成。为了学习稳健的特征并减少可训练参数的数量,请使用预训练模型 MobileNetV2 作为编码器。对于解码器,您将使用上采样块,该块已在 TensorFlow Examples 仓库的 pix2pix 示例中实现。

(2)如前所述,编码器是一个预训练的 MobileNetV2 模型。您将使用来自 tf.keras.applications 的模型。编码器由模型中中间层的特定输出组成。请注意,在训练过程中不会训练编码器。

(3)我们这里使用模型由两部分组成, 一个是编码器 down_stack(也就是下采样器),另一个是解码器 up_stack (也就是上采样器)。我们这里使用预训练的模型 MobileNetV2 作为编码器, MobileNetV2 模型可以直接从网络上下载到本地使用,使用它来进行图片的特征抽取,需要注意的是我们这里选取了模型中的若干中间层,将其作为模型的输出,而且在训练过程中我们设置了不会去训练编码器模型中的权重。对于解码器,我们使用已经在仓库实现了的 pix2pix 。

(4)我们的 U-Net 网络接收的每张图片大小为 [128, 128, 3] ,先通过模型进行下采样,然后计算上采样和 skip 的特征连接,最后经过一层 Conv2DTranspose 输出一个大小为 [batch_size, 128, 128, 3] 的向量结果。

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)
layer_names = [ 'block_1_expand_relu', 'block_3_expand_relu', 'block_6_expand_relu', 'block_13_expand_relu', 'block_16_project']
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False
up_stack = [  pix2pix.upsample(512, 3),  pix2pix.upsample(256, 3),   pix2pix.upsample(128, 3),   pix2pix.upsample(64, 3)]
def unet_model(output_channels:int):
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])
    skips = down_stack(inputs)
    x = skips[-1]
    skips = reversed(skips[:-1])
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])
    last = tf.keras.layers.Conv2DTranspose( filters=output_channels, kernel_size=3, strides=2, padding='same')
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

4. 编译、训练模型

(1)因为每个像素面临的是一个多类分类问题,所以我们使用 SparseCategoricalCrossentropy 作为损失函数,计算多分类问题的交叉熵,并将 from_logits 参数设置为 True,因为标签是用 0、1、2 三个整数表示。SparseCategoricalCrossentropy 函数中当 from_logits=true 时,会先对预测值进行 Softmax 概率化,就无须在模型最后添加 Softmax 层,我们只需要使用经过 Softmax 输出的小数和真实整数标签来计算损失即可。reduction 默认设置为 auto 时,会对一个 batch 的样本损失值求平均。

举例:

y_true = [0,1,2]
y_pred = [[0.2,0.5,0.3],[0.6,0.1,0.3],[0.4,0.4,0.2]]
使用函数结果:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False,name='sparse_categorical_crossentropy')
loss_val = loss_fn(y_true,y_pred).numpy()
loss_val
1.840487
手动计算 SparseCategoricalCrossentropy 结果:
(-np.log(0.2)-np.log(0.1)-np.log(0.2))/3
 1.8404869726207487

(2)使用 Adam 作为优化器,使用 accuracy 作为评估指标。

OUTPUT_CLASSES = 3
EPOCHS = 20
VAL_SUBSPLITS = 5
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS
model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  metrics=['accuracy'])
model_history = model.fit(train_batches, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_steps=VALIDATION_STEPS, validation_data=test_batches)

训练结果输出:

115/115 [==============================] - 110s 961ms/step - loss: 0.1126 - accuracy: 0.9473 - val_loss: 0.3694 - val_accuracy: 0.8897

5. 预测

(1)使用 create_mask 我们会将对该批次的第一张图片的预测掩码图像进行展示,结果是一个大小为 (128, 128, 1) 的向量,其实就是给出了该图片每个像素点的预测标签。

(2)在这里我们使用了上面的一个样本 sample_image ,使用训练好的模型进行预测,因为这里的样本 sample_image 是的大小是 (128, 128, 3) ,我们的模型需要加入 batch_size 维度,所以在第一维扩展了一个维度,大小变为 (1, 128, 128, 3) 才能输入模型。

(3)从绘制的预测掩码图像结果看,预测宠物边界线已经相当清晰了,如果进一步调整模型结果和训练的迭代次数,效果会更加好。

def create_mask(pred_mask):
    pred_mask = tf.math.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]
display([sample_image, sample_mask,  create_mask(model.predict(sample_image[tf.newaxis, ...]))])

以上就是Tensorflow2.10实现图像分割任务示例详解的详细内容,更多关于Tensorflow 图像分割的资料请关注我们其它相关文章!

(0)

相关推荐

  • Tensorflow 2.4 搭建单层和多层 Bi-LSTM 模型

    目录 前言 实现过程 1. 获取数据 2. 处理数据 3. 单层 Bi-LSTM 模型 4. 多层 Bi-LSTM 模型 前言 本文使用 cpu 版本的 TensorFlow 2.4 ,分别搭建单层 Bi-LSTM 模型和多层 Bi-LSTM 模型完成文本分类任务. 确保使用 numpy == 1.19.0 左右的版本,否则在调用 TextVectorization 的时候可能会报 NotImplementedError . 实现过程 1. 获取数据 (1)我们本文用到的数据是电影的影评数据,每

  • Tensorflow2.4从头训练Word Embedding实现文本分类

    目录 前言 具体介绍 1. 三种文本向量化方法 2. 获取数据 3. 处理数据 4. 搭建.训练模型 5. 导出训练好的词嵌入向量 前言 本文主要使用 cpu 版本的 tensorflow 2.4 版本完成文本的 word embedding 训练,并且以此为基础完成影评文本分类任务. 具体介绍 1. 三种文本向量化方法 通常在深度学习模型中我们的输入都是以向量形式存在的,所以我们处理数据过程的重要一项任务就是将文本中的 token (一个 token 可以是英文单词.一个汉字.一个中文词语等,

  • 深度学习Tensorflow 2.4 完成迁移学习和模型微调

    目录 前言 实现过程 1. 获取数据 2. 数据扩充与数据缩放 3. 迁移学习 4. 微调 5. 预测 前言 本文使用 cpu 的 tensorflow 2.4 完成迁移学习和模型微调,并使用训练好的模型完成猫狗图片分类任务. 预训练模型在 NLP 中最常见的可能就是 BERT 了,在 CV 中我们此次用到了 MobileNetV2 ,它也是一个轻量化预训练模型,它已经经过大量的图片分类任务的训练,里面保存了一个可以通用的去捕获图片特征的模型网络结构,其可以通用地提取出图片的有意义特征.这些特征

  • 深度学习TextLSTM的tensorflow1.14实现示例

    目录 对单词最后一个字母的预测 结果打印 对单词最后一个字母的预测 LSTM 的原理自己找,这里只给出简单的示例代码,就是对单词最后一个字母的预测. # LSTM 的原理自己找,这里只给出简单的示例代码 import tensorflow as tf import numpy as np tf.reset_default_graph() # 预测最后一个字母 words = ['make','need','coal','word','love','hate','live','home','has

  • 使用TensorFlow创建生成式对抗网络GAN案例

    目录 导入必要的库和模块 定义训练循环 最后定义主函数 导入必要的库和模块 以下是使用TensorFlow创建一个生成式对抗网络(GAN)的案例: 首先,我们需要导入必要的库和模块: import tensorflow as tf from tensorflow.keras import layers import matplotlib.pyplot as plt import numpy as np 然后,我们定义生成器和鉴别器模型.生成器模型将随机噪声作为输入,并输出伪造的图像.鉴别器模型则

  • 深度学习Tensorflow2.8 使用 BERT 进行文本分类

    目录 前言 1. python 库准备 2. BERT 是什么? 3. 获取并处理 IMDB 数据 4. 初识 TensorFlow Hub 中的 BERT 处理器和模型 5. 搭建模型 6. 训练模型 7. 测试模型 8. 保存模型 9. 重新加载模型并进行预测 前言 本文使用 cpu 版本的 Tensorflow 2.8 ,通过搭建 BERT 模型完成文本分类任务. 1. python 库准备 为了保证能正常运行本文代码,需要保证以下库的版本: tensorflow==2.8.4 tenso

  • 深度学习TextRNN的tensorflow1.14实现示例

    目录 实现对下一个单词的预测 结果打印 实现对下一个单词的预测 RNN 原理自己找,这里只给出简单例子的实现代码 import tensorflow as tf import numpy as np tf.reset_default_graph() sentences = ['i love damao','i like mengjun','we love all'] words = list(set(" ".join(sentences).split())) word2idx = {v

  • SQL实现Excel的10个常用功能的示例详解

    目录 01. 关联公式:Vlookup 02. 对比两列差异 03. 去除重复值 04. 缺失值处理 05. 多条件筛选 06. 模糊筛选数据 07. 分类汇总 08. 条件计算 09. 删除数据间的空格 10. 合并与排序列 SQL笔试题原题 某数据服务公司 某手游公司的SQL笔试题(原题) 某互联网金融公司SQL笔试题(原题) SQL,数据分析岗的必备技能,你可以不懂Python,R,不懂可视化,不懂机器学习.但SQL,你必须懂.要不然领导让你跑个数据来汇......,哦不,你不懂SQL都无

  • AngularJS的Filter的示例详解

    贴上几个有关Filter使用的几个示例. 1. 首先创建一个表格 <body ng-app="app"> <div class="divAll" ng-controller="tableFilter"> <input type="text" placeholder="输入你要搜索的内容" ng-model="key"> <br><br

  • JavaScript中自带的 reduce()方法使用示例详解

    1.方法说明 , Array的reduce()把一个函数作用在这个Array的[x1, x2, x3...]上,这个函数必须接收两个参数,reduce()把结果继续和序列的下一个元素做累积计算,其效果就是: [x1, x2, x3, x4].reduce(f) = f(f(f(x1, x2), x3), x4) 2. 使用示例 'use strict'; function string2int(s){ if(!s){ alert('the params empty'); return; } if

  • ThinkPHP Where 条件中常用表达式示例(详解)

    Where 条件表达式格式为: $map['字段名'] = array('表达式', '操作条件'); 其中 $map 是一个普通的数组变量,可以根据自己需求而命名.上述格式中的表达式实际是运算符的意义: ThinkPHP运算符 与 SQL运算符 对照表 TP运算符 SQL运算符 例子 实际查询条件 eq = $map['id'] = array('eq',100); 等效于:$map['id'] = 100; neq != $map['id'] = array('neq',100); id !

  • Python网络爬虫中的同步与异步示例详解

    一.同步与异步 #同步编程(同一时间只能做一件事,做完了才能做下一件事情) <-a_url-><-b_url-><-c_url-> #异步编程 (可以近似的理解成同一时间有多个事情在做,但有先后) <-a_url-> <-b_url-> <-c_url-> <-d_url-> <-e_url-> <-f_url-> <-g_url-> <-h_url-> <--i_ur

  • 常用JavaScript正则表达式汇编与示例详解

    1.1 前言 目前收集整理了21个常用的javaScript正则表达式,其中包括用户名.密码强度.整数.数字.电子邮件地址(Email).手机号码.身份证号.URL地址. IP地址. 十六进制颜色. 日期. 微信号.车牌号.中文正则等.表单验证处理必备,赶紧收藏吧! 还会陆续加入新的正则进来,大家多提宝贵意见! 2.1 用户名正则 2.1.1 基本用户名正则 在做用户注册时,都会用到用户名正则校验. 定义基本用户名命名规则如下: 最短4位,最长16位 {4,16} 可以包含小写大母 [a-z]

  • Python3的高阶函数map,reduce,filter的示例详解

    函数的参数能接收变量,那么一个函数就可以接收另一个函数作为参数,这种函数就称之为高阶函数. 注意其中:map和filter返回一个惰性序列,可迭代对象,需要转化为list >>> a = 3.1415 >>> round(a,2) 3.14 >>> a_round = round >>> a_round(a,2) 3.14 >>> def func_devide(x, y, f): return f(x) - f(y

  • 示例详解Python3 or Python2 两者之间的差异

    每门编程语言在发布更新之后,主要版本之间都会发生很大的变化. 在本文中,Vinodh Kumar 通过示例解释了 Python 2 和 Python 3 之间的一些重大差异,以帮助说明语言的变化. 本教程主要介绍内容: 表达式 Print 选项 Unequal 操作 Range 自动迁移 性能问题 主要的内部事务更改 1.表达式 在 Python 2 中为获得计算表达式,你会键入: 但在 Python 3 中,你会键入: 因此,无论我们输入什么,值都会分配给 2 和 3 中的变量 x.当在 Py

  • Spring 缓存抽象示例详解

    Spring缓存抽象概述 Spring框架自身并没有实现缓存解决方案,但是从3.1开始定义了org.springframework.cache.Cache和org.springframework.cache.CacheManager接口,提供对缓存功能的声明,能够与多种流行的缓存实现集成. Cache接口为缓存的组件规范定义,包含缓存的各种操作集合: Cache接口下Spring提供了各种xxxCache的实现:如RedisCache,EhCacheCache , ConcurrentMapCa

  • 编译安装redisd的方法示例详解

    安装方法: yum安装 查看yum仓库redis版本 [root@centos ~]# yum list redis Loaded plugins: fastestmirror, langpacks Loading mirror speeds from cached hostfile Available Packages redis.x86_64 3.2.12-2.el7 myepel yum安装 [root@centos ~]# yum install redis -y 启动服务并设为开机启动

随机推荐