利用Pytorch实现获取特征图的方法详解

目录
  • 简单加载官方预训练模型
  • 图片预处理
  • 提取单个特征图
  • 提取多个特征图

简单加载官方预训练模型

torchvision.models预定义了很多公开的模型结构

如果pretrained参数设置为False,那么仅仅设定模型结构;如果设置为True,那么会启动一个下载流程,下载预训练参数

如果只想调用模型,不想训练,那么设置model.eval()和model.requires_grad_(False)

想查看模型参数可以使用modules和named_modules,其中named_modules是一个长度为2的tuple,第一个变量是name,第二个变量是module本身。

# -*- coding: utf-8 -*-
from torch import nn
from torchvision import models

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model.eval()
model.requires_grad_(False)

# get model component
features = model.features
modules = features.modules()
named_modules = features.named_modules()

# print modules
for module in modules:
    if isinstance(module, nn.Conv2d):
        weight = module.weight
        bias = module.bias
        print(module, weight.shape, bias.shape,
              weight.requires_grad, bias.requires_grad)
    elif isinstance(module, nn.ReLU):
        print(module)

print()
for named_module in named_modules:
    name = named_module[0]
    module = named_module[1]
    if isinstance(module, nn.Conv2d):
        weight = module.weight
        bias = module.bias
        print(name, module, weight.shape, bias.shape,
              weight.requires_grad, bias.requires_grad)
    elif isinstance(module, nn.ReLU):
        print(name, module)

图片预处理

使用opencv和pil读图都可以使用transforms.ToTensor()把原本[H, W, 3]的数据转成[3, H, W]的tensor。但opencv要注意把数据改成RGB顺序。

vgg系列模型需要做normalization,建议配合torchvision.transforms来实现。

mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

参考:https://pytorch.org/hub/pytorch_vision_vgg/

# -*- coding: utf-8 -*-
from PIL import Image
import cv2
import torch
from torchvision import transforms

# transforms for preprocess
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# load image using cv2
image_cv2 = cv2.imread('lena_std.bmp')
image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image_cv2 = preprocess(image_cv2)

# load image using pil
image_pil = Image.open('lena_std.bmp')
image_pil = preprocess(image_pil)

# check whether image_cv2 and image_pil are same
print(torch.all(image_cv2 == image_pil))
print(image_cv2.shape, image_pil.shape)

提取单个特征图

如果只提取单层特征图,可以把模型截断,以节省算力和显存消耗。

下面索引之所以有+1是因为pytorch预训练模型里面第一个索引的module总是完整模块结构,第二个才开始子模块。

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')
print(model)

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward
output = model(inputs)
print(output.shape)

提取多个特征图

第一种方式:逐层运行model,如果碰到了需要保存的feature map就存下来。

第二种方式:使用register_forward_hook,使用这种方式需要用一个类把feature map以成员变量的形式缓存下来。

两种方式的运行效率差不多

第一种方式简单直观,但是只能处理类似VGG这种没有跨层连接的网络;第二种方式更加通用。

# -*- coding: utf-8 -*-
from PIL import Image
import torch
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    print('-------- %s --------' % name)
    print(module)
    print()

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward - 1
layers = [2, 7, 8, 9, 16]
layers = sorted(set(layers))
feature_maps = {}
feature = inputs
for i in range(max(layers) + 1):
    feature = model[i](feature)
    if i in layers:
        feature_maps[i] = feature
for key in feature_maps:
    print(key, feature_maps.get(key).shape)

# forward - 2
class FeatureHook:
    def __init__(self, module):
        self.inputs = None
        self.output = None
        self.hook = module.register_forward_hook(self.get_features)

    def get_features(self, module, inputs, output):
        self.inputs = inputs
        self.output = output

layer_names = ['2', '7', '8', '9', '16']
hook_modules = []
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    if name in layer_names:
        hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
    print(feature.shape)

# check correctness
for i, layer in enumerate(layers):
    feature1 = feature_maps.get(layer)
    feature2 = features[i]
    print(torch.all(feature1 == feature2))

使用第二种方式(register_forward_hook),resnet特征图也可以顺利拿到。

而由于resnet的model已经不可以用model[i]的形式索引,所以无法使用第一种方式。

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.resnet18(pretrained=True)
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    print('-------- %s --------' % name)
    print(module)
    print()

# load and preprocess image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

class FeatureHook:
    def __init__(self, module):
        self.inputs = None
        self.output = None
        self.hook = module.register_forward_hook(self.get_features)

    def get_features(self, module, inputs, output):
        self.inputs = inputs
        self.output = output

layer_names = [
    'conv1',
    'layer1.0.relu',
    'layer2.0.conv1'
]

hook_modules = []
for named_module in model.named_modules():
    name = named_module[0]
    module = named_module[1]
    if name in layer_names:
        hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
    print(feature.shape)

问题来了,resnet这种类型的网络结构怎么截断?

使用如下命令就可以,print查看需要截断到哪里,然后用nn.Sequential重组即可。

需注意重组后网络的module_name会发生变化。

print(list(model.children())
model = torch.nn.Sequential(*list(model.children())[:6])

以上就是利用Pytorch实现获取特征图的方法详解的详细内容,更多关于Pytorch获取特征图的资料请关注我们其它相关文章!

(0)

相关推荐

  • python实现拉普拉斯特征图降维示例

    这种方法假设样本点在光滑的流形上,这一方法的计算数据的低维表达,局部近邻信息被最优的保存.以这种方式,可以得到一个能反映流形的几何结构的解. 步骤一:构建一个图G=(V,E),其中V={vi,i=1,2,3-n}是顶点的集合,E={eij}是连接顶点的vi和vj边,图的每一个节点vi与样本集X中的一个点xi相关.如果xi,xj相距较近,我们就连接vi,vj.也就是说在各自节点插入一个边eij,如果Xj在xi的k领域中,k是定义参数. 步骤二:每个边都与一个权值Wij相对应,没有连接点之间的权值为

  • Python基于Pytorch的特征图提取实例

    目录 简述 单个图片的提取 神经网络的构建 特征图的提取 可视化展示 完整代码 总结 简述 为了方便理解卷积神经网络的运行过程,需要对卷积神经网络的运行结果进行可视化的展示. 大致可分为如下步骤: 单个图片的提取 神经网络的构建 特征图的提取 可视化展示 单个图片的提取 根据目标要求,需要对单个图片进行卷积运算,但是Pytorch中读取数据主要用到torch.utils.data.DataLoader类,因此我们需要编写单个图片的读取程序 def get_picture(picture_dir,

  • 使用pytorch提取卷积神经网络的特征图可视化

    目录 前言 1. 效果图 2. 完整代码 3. 代码说明 4. 可视化梯度,feature 总结 前言 文章中的代码是参考基于Pytorch的特征图提取编写的代码本身很简单这里只做简单的描述. 1. 效果图 先看效果图(第一张是原图,后面的都是相应的特征图,这里使用的网络是resnet50,需要注意的是下面图片显示的特征图是经过放大后的图,原图是比较小的图,因为太小不利于我们观察): 2. 完整代码 import os import torch import torchvision as tv

  • 利用Pytorch实现获取特征图的方法详解

    目录 简单加载官方预训练模型 图片预处理 提取单个特征图 提取多个特征图 简单加载官方预训练模型 torchvision.models预定义了很多公开的模型结构 如果pretrained参数设置为False,那么仅仅设定模型结构:如果设置为True,那么会启动一个下载流程,下载预训练参数 如果只想调用模型,不想训练,那么设置model.eval()和model.requires_grad_(False) 想查看模型参数可以使用modules和named_modules,其中named_modul

  • Java利用Request请求获取IP地址的方法详解

    前言 最近在项目中遇到一个需求,是需要将不同省份的用户,展示不同内容,通过查找相关的资料,发现可以通过Request请求获取IP地址,下面我们先来贴代码, 如果你要在生产环境使用就直接拿去用吧,我这边已经上线了. 示例代码 public class IpAdrressUtil { /** * 获取Ip地址 * @param request * @return */ private static String getIpAdrress(HttpServletRequest request) { S

  • 利用JavaScript获取用户IP属地方法详解

    目录 写在前面 尝试一:navigator.geolocation 尝试二:sohu 的接口 尝试三:百度地图的接口 写在后面 写在前面 想要像一些平台那样显示用户的位置信息,例如某省市那样.那么这是如何做到的, 据说这个位置信息的准确性在通信网络运营商那里?先不管,先实践尝试下能不能获取. 尝试一:navigator.geolocation 尝试了使用 navigator.geolocation,但未能成功拿到信息. getGeolocation(){ if ('geolocation' in

  • pytorch对可变长度序列的处理方法详解

    主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法. 1.torch.nn.utils.rnn.PackedSequence() NOTE: 这个类的实例不能手动创建.它们只能被 pack_padded_sequence() 实例化. PackedSequence

  • pytorch的梯度计算以及backward方法详解

    基础知识 tensors: tensor在pytorch里面是一个n维数组.我们可以通过指定参数reuqires_grad=True来建立一个反向传播图,从而能够计算梯度.在pytorch中一般叫做dynamic computation graph(DCG)--即动态计算图. import torch import numpy as np # 方式一 x = torch.randn(2,2, requires_grad=True) # 方式二 x = torch.autograd.Variabl

  • Matlab绘制雨云图的方法详解

    目录 介绍 横向雨云图 纵向雨云图 介绍 写了俩代码模板,用来绘制横向云雨图与纵向云雨图,云雨图其实就是用把小提琴图拆开来的模板,想获取小提琴图绘制函数的可以看这里:基于Matlab绘制小提琴图的示例代码 后面的俩模板用的时候只需要换换数据,颜色及每一类名称即可,雨云图绘制效果如下: 横向雨云图 function rainCloudsTMPL1 % @author: slandarer % 在这里放入你的数据============================================

  • Vue利用openlayers实现点击弹窗的方法详解

    目录 解释 编写弹窗 引入 openlayer使用弹窗组件 点击事件 这个写的稍微简单一点就行了,其实呢,这个不是很难,主要是知道原理就可以了. 我想实现的内容是什么意思呢?就是说页面上有很多坐标点,点击坐标点的时候在相应的位置弹出一个框,然后框里显示出这个坐标点的相关数据. 解释 这个内容的其实就是添加一个弹窗图层,然后在点击的时候让他显示出来罢了. 编写弹窗 首先一点,我们这个弹窗需要自己写一下,具体的样式,展示的内容之类的,所以说写一个弹窗组件,然后在openlayer文件中引用加载. 比

  • EasyX绘制透明背景图的方法详解

    目录 三元光栅操作 优化方案 三元光栅操作 根据在网上的搜索总结得到两种方案,最常见的绘制带有透明背景的图像的方案都是采用如下的源图像和掩码图像叠加来消去边缘部分: IMAGE img[2]; loadimage(&img[0], "sun1.png", 100, 100); // 掩码图像 loadimage(&img[1], "sun0.png", 100, 100); // 源图像 putimage(0, 0, &img[0], NOT

  • 利用Android实现光影流动特效的方法详解

    目录 前言 MaskFilter 类简介 MaskFilter 的几种效果对比 光影流动 光影流动效果1 光影流动效果2 光影流动效果3 光影流动效果4:光影沿贝塞尔曲线流动 总结 前言 Flutter 的画笔类 Paint 提供了很多图形绘制的配置属性,来供我们绘制更丰富多彩的图形.前面几篇我们介绍了 shader 属性来绘制全屏渐变的聊天气泡背景.渐变流动的边框和毛玻璃效果的背景图片,具体可以参考下面几篇文章. 让你的聊天气泡丰富多彩! 手把手教你实现一个流动的渐变色边框 利用光影变化构建立

  • 利用Vue3实现可复制表格的方法详解

    目录 前言 最基础的表格封装 实现复制功能 处理表格中的不可复制元素 测试 前言 表格是前端非常常用的一个控件,但是每次都使用v-for指令手动绘制tr/th/td这些元素是非常麻烦的.同时,基础的 table 样式通常也是不满足需求的,因此一个好的表格封装就显得比较重要了. 最基础的表格封装 最基础基础的表格封装所要做的事情就是让用户只关注行和列的数据,而不需要关注 DOM 结构是怎样的,我们可以参考 AntDesign,columns dataSource 这两个属性是必不可少的,代码如下:

随机推荐