Pytorch图像处理注意力机制解析及代码详解

什么是注意力机制

注意力机制是一个非常有效的trick,注意力机制的实现方式有许多,我们一起来学习一下

注意力机制是深度学习常用的一个小技巧,它有多种多样的实现形式,尽管实现方式多样,但是每一种注意力机制的实现的核心都是类似的,就是注意力。

注意力机制的核心重点就是让网络关注到它更需要关注的地方。

当我们使用卷积神经网络去处理图片的时候,我们会更希望卷积神经网络去注意应该注意的地方,而不是什么都关注,我们不可能手动去调节需要注意的地方,这个时候,如何让卷积神经网络去自适应的注意重要的物体变得极为重要。

注意力机制就是实现网络自适应注意的一个方式。

一般而言,注意力机制可以分为通道注意力机制,空间注意力机制,以及二者的结合。

代码下载

复制该路径到地址栏跳转。

注意力机制的实现方式

在深度学习中,常见的注意力机制的实现方式有SENet,CBAM,ECA等等。

1、SENet的实现

SENet是通道注意力机制的典型实现。

2017年提出的SENet是最后一届ImageNet竞赛的冠军,其实现示意图如下所示,对于输入进来的特征层,我们关注其每一个通道的权重,对于SENet而言,其重点是获得输入进来的特征层,每一个通道的权值。利用SENet,我们可以让网络关注它最需要关注的通道。

其具体实现方式就是:

1、对输入进来的特征层进行全局平均池化。

2、然后进行两次全连接,第一次全连接神经元个数较少,第二次全连接神经元个数和输入特征层相同。

3、在完成两次全连接后,我们再取一次Sigmoid将值固定到0-1之间,此时我们获得了输入特征层每一个通道的权值(0-1之间)。

4、在获得这个权值后,我们将这个权值乘上原输入特征层即可。

实现代码如下:

import torch
import torch.nn as nn
import math
class se_block(nn.Module):
    def __init__(self, channel, ratio=16):
        super(se_block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // ratio, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // ratio, channel, bias=False),
                nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

2、CBAM的实现

CBAM将通道注意力机制和空间注意力机制进行一个结合,相比于SENet只关注通道的注意力机制可以取得更好的效果。其实现示意图如下所示,CBAM会对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理。

下图是通道注意力机制和空间注意力机制的具体实现方式:

图像的上半部分为通道注意力机制,通道注意力机制的实现可以分为两个部分,我们会对输入进来的单个特征层,分别进行全局平均池化和全局最大池化。之后对平均池化和最大池化的结果,利用共享的全连接层进行处理,我们会对处理后的两个结果进行相加,然后取一个sigmoid,此时我们获得了输入特征层每一个通道的权值(0-1之间)。

在获得这个权值后,我们将这个权值乘上原输入特征层即可。

图像的下半部分为空间注意力机制,我们会对输入进来的特征层,在每一个特征点的通道上取最大值和平均值。之后将这两个结果进行一个堆叠,利用一次通道数为1的卷积调整通道数,然后取一个sigmoid,此时我们获得了输入特征层每一个特征点的权值(0-1之间)。

在获得这个权值后,我们将这个权值乘上原输入特征层即可。

实现代码如下:

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # 利用1x1卷积代替全连接
        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
class cbam_block(nn.Module):
    def __init__(self, channel, ratio=8, kernel_size=7):
        super(cbam_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)
    def forward(self, x):
        x = x * self.channelattention(x)
        x = x * self.spatialattention(x)
        return x

3、ECA的实现

ECANet是也是通道注意力机制的一种实现形式。ECANet可以看作是SENet的改进版。

ECANet的作者认为SENet对通道注意力机制的预测带来了副作用,捕获所有通道的依赖关系是低效并且是不必要的。

在ECANet的论文中,作者认为卷积具有良好的跨通道信息获取能力。

ECA模块的思想是非常简单的,它去除了原来SE模块中的全连接层,直接在全局平均池化之后的特征上通过一个1D卷积进行学习。

既然使用到了1D卷积,那么1D卷积的卷积核大小的选择就变得非常重要了,了解过卷积原理的同学很快就可以明白,1D卷积的卷积核大小会影响注意力机制每个权重的计算要考虑的通道数量。用更专业的名词就是跨通道交互的覆盖率。

如下图所示,左图是常规的SE模块,右图是ECA模块。ECA模块用1D卷积替换两次全连接。

实现代码如下:

class eca_block(nn.Module):
    def __init__(self, channel, b=1, gamma=2):
        super(eca_block, self).__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

注意力机制的应用

注意力机制是一个即插即用的模块,理论上可以放在任何一个特征层后面,可以放在主干网络,也可以放在加强特征提取网络。

由于放置在主干会导致网络的预训练权重无法使用,本文以YoloV4-tiny为例,将注意力机制应用加强特征提取网络上。

如下图所示,我们在主干网络提取出来的两个有效特征层上增加了注意力机制,同时对上采样后的结果增加了注意力机制。

实现代码如下:

attention_block = [se_block, cbam_block, eca_block]
#---------------------------------------------------#
#   特征层->最后的输出
#---------------------------------------------------#
class YoloBody(nn.Module):
    def __init__(self, anchors_mask, num_classes, phi=0):
        super(YoloBody, self).__init__()
        self.phi            = phi
        self.backbone       = darknet53_tiny(None)
        self.conv_for_P5    = BasicConv(512,256,1)
        self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)
        self.upsample       = Upsample(256,128)
        self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)
        if 1 <= self.phi and self.phi <= 3:
            self.feat1_att      = attention_block[self.phi - 1](256)
            self.feat2_att      = attention_block[self.phi - 1](512)
            self.upsample_att   = attention_block[self.phi - 1](128)
    def forward(self, x):
        #---------------------------------------------------#
        #   生成CSPdarknet53_tiny的主干模型
        #   feat1的shape为26,26,256
        #   feat2的shape为13,13,512
        #---------------------------------------------------#
        feat1, feat2 = self.backbone(x)
        if 1 <= self.phi and self.phi <= 3:
            feat1 = self.feat1_att(feat1)
            feat2 = self.feat2_att(feat2)
        # 13,13,512 -> 13,13,256
        P5 = self.conv_for_P5(feat2)
        # 13,13,256 -> 13,13,512 -> 13,13,255
        out0 = self.yolo_headP5(P5)
        # 13,13,256 -> 13,13,128 -> 26,26,128
        P5_Upsample = self.upsample(P5)
        # 26,26,256 + 26,26,128 -> 26,26,384
        if 1 <= self.phi and self.phi <= 3:
            P5_Upsample = self.upsample_att(P5_Upsample)
        P4 = torch.cat([P5_Upsample,feat1],axis=1)
        # 26,26,384 -> 26,26,256 -> 26,26,255
        out1 = self.yolo_headP4(P4)
        return out0, out1

以上就是Pytorch图像处理注意力机制解析及代码详解的详细内容,更多关于Pytorch图像处理注意力机制的资料请关注我们其它相关文章!

(0)

相关推荐

  • pytorch 带batch的tensor类型图像显示操作

    项目场景 pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练.训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化. 那么如何显示dataloader里面带batch的tensor类型的图像呢? 显示图像 绘图最常用的库就是matplotlib: pip install matplotlib 显示图像会用到matplotlib.pyplot.imshow方法.查阅官方文档可知,该方法接收的图像的通道数要放到后面: 数据加载器中数据的维度是[B, C, H, W],

  • pytorch加载自己的图像数据集实例

    之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码.到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题. 参考文章https://www.jb51.net/article/177613.htm 下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor.最后读取第一张图片并显示. # 数据处理 import os import torch from torch.utils import data fr

  • Pytorch实现图像识别之数字识别(附详细注释)

    使用了两个卷积层加上两个全连接层实现 本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份 详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上 代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释 Python实现代码: import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transf

  • Pytorch 使用CNN图像分类的实现

    需求 在4*4的图片中,比较外围黑色像素点和内圈黑色像素点个数的大小将图片分类 如上图图片外围黑色像素点5个大于内圈黑色像素点1个分为0类反之1类 想法 通过numpy.PIL构造4*4的图像数据集 构造自己的数据集类 读取数据集对数据集选取减少偏斜 cnn设计因为特征少,直接1*1卷积层 或者在4*4外围添加padding成6*6,设计2*2的卷积核得出3*3再接上全连接层 代码 import torch import torchvision import torchvision.transf

  • pytorch实现图像识别(实战)

    目录 1. 代码讲解 1.1 导库 1.2 标准化.transform.设置GPU 1.3 预处理数据 1.4 建立模型 1.5 训练模型 1.6 测试模型 1.7结果 1. 代码讲解 1.1 导库 import os.path from os import listdir import numpy as np import pandas as pd from PIL import Image import torch import torch.nn as nn import torch.nn.

  • Pytorch图像处理注意力机制解析及代码详解

    什么是注意力机制 注意力机制是一个非常有效的trick,注意力机制的实现方式有许多,我们一起来学习一下 注意力机制是深度学习常用的一个小技巧,它有多种多样的实现形式,尽管实现方式多样,但是每一种注意力机制的实现的核心都是类似的,就是注意力. 注意力机制的核心重点就是让网络关注到它更需要关注的地方. 当我们使用卷积神经网络去处理图片的时候,我们会更希望卷积神经网络去注意应该注意的地方,而不是什么都关注,我们不可能手动去调节需要注意的地方,这个时候,如何让卷积神经网络去自适应的注意重要的物体变得极为

  • Django权限机制实现代码详解

    本文研究的主要是Django权限机制的相关内容,具体如下. 1. Django权限机制概述 权限机制能够约束用户行为,控制页面的显示内容,也能使API更加安全和灵活:用好权限机制,能让系统更加强大和健壮.因此,基于Django的开发,理清Django权限机制是非常必要的. 1.1 Django的权限控制 Django用user, group和permission完成了权限机制,这个权限机制是将属于model的某个permission赋予user或group,可以理解为全局的权限,即如果用户A对数

  • Java多线程之线程通信生产者消费者模式及等待唤醒机制代码详解

    前言 前面的例子都是多个线程在做相同的操作,比如4个线程都对共享数据做tickets–操作.大多情况下,程序中需要不同的线程做不同的事,比如一个线程对共享变量做tickets++操作,另一个线程对共享变量做tickets–操作,这就是大名鼎鼎的生产者和消费者模式. 正文 一,生产者-消费者模式也是多线程 生产者和消费者模式也是多线程的范例.所以其编程需要遵循多线程的规矩. 首先,既然是多线程,就必然要使用同步.上回说到,synchronized关键字在修饰函数的时候,使用的是"this"

  • python数字图像处理之高级滤波代码详解

    本文提供许多的滤波方法,这些方法放在filters.rank子模块内. 这些方法需要用户自己设定滤波器的形状和大小,因此需要导入morphology模块来设定. 1.autolevel 这个词在photoshop里面翻译成自动色阶,用局部直方图来对图片进行滤波分级. 该滤波器局部地拉伸灰度像素值的直方图,以覆盖整个像素值范围. 格式:skimage.filters.rank.autolevel(image, selem) selem表示结构化元素,用于设定滤波器. from skimage im

  • 通过反射实现Java下的委托机制代码详解

    简述 一直对Java没有现成的委托机制耿耿于怀,所幸最近有点时间,用反射写了一个简单的委托模块,以供参考. 模块API public Class Delegater()//空参构造,该类管理委托实例并实现委托方法 //添加一个静态方法委托,返回整型值ID代表该方法与参数构成的实例.若失败,则返回-1. public synchronized int addFunctionDelegate(Class<?> srcClass,String methodName,Object... params)

  • java中注解机制及其原理的详解

    java中注解机制及其原理的详解 什么是注解 注解也叫元数据,例如我们常见的@Override和@Deprecated,注解是JDK1.5版本开始引入的一个特性,用于对代码进行说明,可以对包.类.接口.字段.方法参数.局部变量等进行注解.它主要的作用有以下四方面: 生成文档,通过代码里标识的元数据生成javadoc文档. 编译检查,通过代码里标识的元数据让编译器在编译期间进行检查验证. 编译时动态处理,编译时通过代码里标识的元数据动态处理,例如动态生成代码. 运行时动态处理,运行时通过代码里标识

  • apache commons工具集代码详解

    Apache Commons包含了很多开源的工具,用于解决平时编程经常会遇到的问题,减少重复劳动.下面是我这几年做开发过程中自己用过的工具类做简单介绍. 组件 功能介绍 BeanUtils 提供了对于JavaBean进行各种操作,克隆对象,属性等等. Betwixt XML与Java对象之间相互转换. Codec 处理常用的编码方法的工具类包 例如DES.SHA1.MD5.Base64等. Collections java集合框架操作. Compress java提供文件打包 压缩类库. Con

  • Jedis对redis的五大类型操作代码详解

    本篇主要阐述Jedis对redis的五大类型的操作:字符串.列表.散列.集合.有序集合. JedisUtil 这里的测试用例采用junit4进行运行,准备代码如下: private static final String ipAddr = "10.10.195.112"; private static final int port = 6379; private static Jedis jedis= null; @BeforeClass public static void init

  • Linux下tcpdump命令解析及使用详解

    简介 用简单的话来定义tcpdump,就是:dump the traffic on a network,根据使用者的定义对网络上的数据包进行截获的包分析工具.tcpdump可以将网络中传送的数据包的"头"完全截获下来提供分析.它支持针对网络层.协议.主机.网络或端口的过滤,并提供and.or.not等逻辑语句来帮助你去掉无用的信息. 实用命令实例 默认启动 tcpdump 普通情况下,直接启动tcpdump将监视第一个网络接口上所有流过的数据包. 监视指定网络接口的数据包 tcpdum

  • SQL行转列和列转行代码详解

    行列互转,是一个经常遇到的需求.实现的方法,有case when方式和2005之后的内置pivot和unpivot方法来实现. 在读了技术内幕那一节后,虽说这些解决方案早就用过了,却没有系统性的认识和总结过.为了加深认识,再总结一次. 行列互转,可以分为静态互转,即事先就知道要处理多少行(列);动态互转,事先不知道处理多少行(列). --创建测试环境 USE tempdb; GO IF OBJECT_ID('dbo.Orders') IS NOT NULL DROP TABLE dbo.Orde

随机推荐