分享Pytorch获取中间层输出的3种方法

目录
  • 【1】方法一:获取nn.Sequential的中间层输出
  • 【2】方法二:IntermediateLayerGetter
  • 【3】方法三:钩子

【1】方法一:获取nn.Sequential的中间层输出

import torch
import torch.nn as nn
model = nn.Sequential(
            nn.Conv2d(3, 9, 1, 1, 0, bias=False),
            nn.BatchNorm2d(9),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

# 假如想要获得ReLu的输出
x = torch.rand([2, 3, 224, 224])
for i in range(len(model)):
    x = model[i](x)
    if i == 2:
        ReLu_out = x
print('ReLu_out.shape:\n\t',ReLu_out.shape)
print('x.shape:\n\t',x.shape)

结果:

ReLu_out.shape:
  torch.Size([2, 9, 224, 224])
x.shape:
  torch.Size([2, 9, 1, 1])

【2】方法二:IntermediateLayerGetter

from collections import OrderedDict
 
import torch
from torch import nn
 
 
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model
    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.
    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """
    
    def __init__(self, model, return_layers):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
 
        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break
 
        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers
 
    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
# example
m = torchvision.models.resnet18(pretrained=True)
# extract layer1 and layer3, giving as names `feat1` and feat2`
new_m = torchvision.models._utils.IntermediateLayerGetter(m,{'layer1': 'feat1', 'layer3': 'feat2'})
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
# [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]

作用:

在定义它的时候注明作用的模型(如下例中的m)和要返回的layer(如下例中的layer1,layer3),得到new_m。

使用时喂输入变量,返回的就是对应的layer

举例:

m = torchvision.models.resnet18(pretrained=True)
 # extract layer1 and layer3, giving as names `feat1` and feat2`
new_m = torchvision.models._utils.IntermediateLayerGetter(m,{'layer1': 'feat1', 'layer3': 'feat2'})
out = new_m(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])

输出结果:

[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]

【3】方法三:钩子

class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()

        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

features_in_hook = []
features_out_hook = []

def hook(module, fea_in, fea_out):
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

net = TestForHook()

第一种写法,按照类型勾,但如果有重复类型的layer比较复杂

net_chilren = net.children()
for child in net_chilren:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)

推荐下面我改的这种写法,因为我自己的网络中,在Sequential中有很多层,
这种方式可以直接先print(net)一下,找出自己所需要那个layer的名称,按名称勾出来

layer_name = 'relu_6'
for (name, module) in net.named_modules():
    if name == layer_name:
        module.register_forward_hook(hook=hook)

print(features_in_hook)  # 勾的是指定层的输入
print(features_out_hook)  # 勾的是指定层的输出

到此这篇关于分享Pytorch获取中间层输出的3种方法的文章就介绍到这了,更多相关Pytorch获取中间层输出方法内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch在fintune时将sequential中的层输出方法,以vgg为例

    有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈! 首先pytorch自带的vgg16模型的网络结构如下: VGG( (features): Sequential( (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) (2): Co

  • pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法

    如下所示: #获取模型权重 for k, v in model_2.state_dict().iteritems(): print("Layer {}".format(k)) print(v) #获取模型权重 for layer in model_2.modules(): if isinstance(layer, nn.Linear): print(layer.weight) #将一个模型权重载入另一个模型 model = VGG(make_layers(cfg['E']), **kw

  • 分享Pytorch获取中间层输出的3种方法

    目录 [1]方法一:获取nn.Sequential的中间层输出 [2]方法二:IntermediateLayerGetter [3]方法三:钩子 [1]方法一:获取nn.Sequential的中间层输出 import torch import torch.nn as nn model = nn.Sequential(             nn.Conv2d(3, 9, 1, 1, 0, bias=False),             nn.BatchNorm2d(9),          

  • PHP获取对象属性的三种方法实例分析

    本文实例讲述了PHP获取对象属性的三种方法.分享给大家供大家参考,具体如下: 今天查看yii源码,发现yii\base\Model中的attribute()方法是通过反射获取对象的public non-static属性.记得以前看到的代码都是用get_object_vars()这个函数获取的,昨天查看php文档,发现还可以用foreach遍历对象属性.于是写个例子练习下. class TestClass { private $a; protected $b; public $c; public

  • asp.net实现图片以二进制流输出的两种方法

    本文实例讲述了asp.net实现图片以二进制流输出的两种方法.分享给大家供大家参考,具体如下: 方法一: System.IO.MemoryStream ms = new System.IO.MemoryStream(); System.IO.Stream str = new FileUpload().PostedFile.InputStream; System.Drawing.Bitmap map = new System.Drawing.Bitmap(str); map.Save(ms, Sy

  • JS获取地址栏参数的两种方法(简单实用)

    js获取地址栏参数的方法有两种:第一种,采用正则表达式获取地址栏参数,第二种,是比较传统的方法,在这小编给大家强烈推荐使用第一种方法,既方便有实用,具体实现过程请看下文详述. 方法一:采用正则表达式获取地址栏参数:( 强烈推荐,既实用又方便!) function GetQueryString(name) { var reg = new RegExp("(^|&)"+ name +"=([^&]*)(&|$)"); var r = window

  • python实现中文输出的两种方法

    本文实例讲述了python实现中文输出的两种方法.分享给大家供大家参考.具体如下: 方法一: 用encode和decode 如: import os.path import xlrd,sys Filename='/home/tom/Desktop/1234.xls' if not os.path.isfile(Filename): raise NameError,"%s is not a valid filename"%Filename bk=xlrd.open_workbook(Fi

  • 对python-3-print重定向输出的几种方法总结

    方法1: import sys f=open('test.txt','a+') a='123' b='456' print >> f,a,b f.close() 方法2: import sys f=open('a.txt','w') old=sys.stdout #将当前系统输出储存到临时变量 sys.stdout=f #输出重定向到文件 print 'Hello World!' #测试一个打印输出 sys.stdout=old #还原系统输出 f.close() print open('a.

  • 微信小程序获取用户信息的两种方法wx.getUserInfo与open-data实例分析

    本文实例讲述了微信小程序获取用户信息的两种方法wx.getUserInfo与open-data.分享给大家供大家参考,具体如下: 在此之前,小程序获取微信的头像,昵称之类的用户信息,我用的都是wx.getUserInfo,例如: onLoad: function (options) { var that = this; //获取用户信息 wx.getUserInfo({ success: function (res) { console.log(res); that.data.userInfo

  • PHP实现获取ip地址的5种方法,以及插入用户登录日志操作示例

    本文实例讲述了PHP实现获取ip地址的5种方法,以及插入用户登录日志操作.分享给大家供大家参考,具体如下: php 获取ip地址的5种方法,插入用户登录日志实例,推荐使用第二种方法 <?php //方法1: $ip = $_SERVER["REMOTE_ADDR"]; echo $ip; //方法2: $ip = ($_SERVER["HTTP_VIA"]) ? $_SERVER["HTTP_X_FORWARDED_FOR"] : $_SE

  • Java中获取键盘输入值的三种方法介绍

    程序开发过程中,需要从键盘获取输入值是常有的事,但Java它偏偏就没有像c语言给我们提供的scanf(),C++给我们提供的cin()获取键盘输入值的现成函数!Java没有提供这样的函数也不代表遇到这种情况我们就束手无策,请你看以下三种解决方法吧: 以下将列出几种方法: 方法一:从控制台接收一个字符,然后将其打印出来 public static void main(String [] args) throws IOException{ System.out.print("Enter a char

  • Python爬虫后获取重定向url的两种方法

    下面给大家分享Python爬虫后获取重定向url的两种方法,具体内容如下所示: 方法(一) # 获得重定向url from urllib import request # https://zhidao.baidu.com/question/681501874175782812.html url = "https://www.baidu.com/link?url=IscBx0u8h9q4Uq3ihTs_PqnoNWe7slVWAd2dowQKrnqJedvthb3zrh9JqcMJu3ZqFrbW

随机推荐