pytorch中的hook机制register_forward_hook

目录
  • 1、hook背景
  • 2、源码阅读
  • 3、定义一个用于测试hooker的类
  • 4、定义hook函数
  • 5、对需要的层注册hook
  • 6、测试forward()返回的特征和hook记录的是否一致
    • 6.1测试forward()提供的输入输出特征
    • 6.2hook记录的输入特征和输出特征
    • 6.3把hook记录的和forward做减法
  • 7、完整代码

1、hook背景

Hook被成为钩子机制,这不是pytorch的首创,在Windows的编程中已经被普遍采用,包括进程内钩子和全局钩子。按照自己的理解,hook的作用是通过系统来维护一个链表,使得用户拦截(获取)通信消息,用于处理事件。

pytorch中包含forwardbackward两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。

2、源码阅读

register_forward_hook()函数必须在forward()函数调用之前被使用,因为这个函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:`forward` is called”,也就是这个函数在forward()之后就没有作用了!!!):

作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。

def register_forward_hook(self, hook):
        r"""Registers a forward hook on the module.
        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::
            hook(module, input, output) -> None or modified output
        The hook can modify the output. It can modify the input inplace but
        it will not have effect on forward since this is called after
        :func:`forward` is called.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

3、定义一个用于测试hooker的类

如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),__init__()中调用initialize函数对所有层进行初始化。

注意:在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。

class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()

        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out
    def initialize(self):
        """ 定义特殊的初始化,用于验证是不是获取了权重"""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True

4、定义hook函数

hook()函数是register_forward_hook()函数必须提供的参数,好处是“用户可以自行决定拦截了中间信息之后要做什么!”,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。

首先定义几个容器用于记录:

定义用于获取网络各层输入输出tensor的容器:

# 并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:

hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字

def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

5、对需要的层注册hook

注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):

注册钩子可以对某些层单独进行:

net = TestForHook()
net_chilren = net.children()
for child in net_chilren:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)

6、测试forward()返回的特征和hook记录的是否一致

6.1 测试forward()提供的输入输出特征

由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

得到下面的输出是理所当然的:

*****forward return features*****
(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****

6.2 hook记录的输入特征和输出特征

hook通过list结构进行记录,所以可以直接print

测试features_in是不是存储了输入:

print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

得到和forward一样的结果:

*****hook record features*****
[(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>, 
<class 'torch.nn.modules.linear.Linear'>,
 <class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****

6.3 把hook记录的和forward做减法

如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:

测试forward返回的feautes_in是不是和hook记录的一致:

print("sub result'")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
    print(forward_return-hook_record[0])

得到的全部都是0,说明hook没问题:

sub result
tensor([[0., 0.],
        [0., 0.]])
tensor([[0., 0.],
        [0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],
        [0.]], grad_fn=<SubBackward0>)

7、完整代码

import torch
import torch.nn as nn

class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()

        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()

    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out

    def initialize(self):
        """ 定义特殊的初始化,用于验证是不是获取了权重"""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True

定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字

module_name = []
features_in_hook = []
features_out_hook = []

hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字

def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

定义全部是1的输入:

x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])

注册钩子可以对某些层单独进行:

net = TestForHook()
net_chilren = net.children()
for child in net_chilren:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)

测试网络输出:

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

测试features_in是不是存储了输入:

print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

测试forward返回的feautes_in是不是和hook记录的一致:

print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
    print(forward_return-hook_record[0])

到此这篇关于pytorch中的hook机制register_forward_hook的文章就介绍到这了,更多相关pytorch中的hook机制内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法

    如下所示: #获取模型权重 for k, v in model_2.state_dict().iteritems(): print("Layer {}".format(k)) print(v) #获取模型权重 for layer in model_2.modules(): if isinstance(layer, nn.Linear): print(layer.weight) #将一个模型权重载入另一个模型 model = VGG(make_layers(cfg['E']), **kw

  • pytorch中的hook机制register_forward_hook

    目录 1.hook背景 2.源码阅读 3.定义一个用于测试hooker的类 4.定义hook函数 5.对需要的层注册hook 6.测试forward()返回的特征和hook记录的是否一致 6.1测试forward()提供的输入输出特征 6.2hook记录的输入特征和输出特征 6.3把hook记录的和forward做减法 7.完整代码 1.hook背景 Hook被成为钩子机制,这不是pytorch的首创,在Windows的编程中已经被普遍采用,包括进程内钩子和全局钩子.按照自己的理解,hook的作

  • 浅谈PHP中如何实现Hook机制

    对"钩子"这个概念其实不熟悉,最近看到一个php框架中用到这种机制来扩展项目,所以大概来了解下. 所谓Hook机制,是从Windows编程中流行开的一种技术.其主要思想是提前在可能增加功能的地方埋好(预设)一个钩子,这个钩子并没有实际的意义,当我们需要重新修改或者增加这个地方的逻辑的时候,把扩展的类或者方法挂载到这个点即可. hook插件机制的基本思想: 在项目代码中,你认为要扩展(暂时不扩展)的地方放置一个钩子函数,等需要扩展的时候,把需要实现的类和函数挂载到这个钩子上,就可以实现扩

  • Pytorch中的自动求梯度机制和Variable类实例

    自动求导机制是每一个深度学习框架中重要的性质,免去了手动计算导数,下面用代码介绍并举例说明Pytorch的自动求导机制. 首先介绍Variable,Variable是对Tensor的一个封装,操作和Tensor是一样的,但是每个Variable都有三个属性:Varibale的Tensor本身的.data,对应Tensor的梯度.grad,以及这个Variable是通过什么方式得到的.grad_fn,根据最新消息,在pytorch0.4更新后,torch和torch.autograd.Variab

  • pytorch中获取模型input/output shape实例

    Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见 https://github.com/pytorch/pytorch/pull/3043 以下代码算一种workaround.由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码. 例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了. 该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape #coding

  • pytorch 中的重要模块化接口nn.Module的使用

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 查看源码 初始化部分: def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = Ordered

  • 5 分钟读懂Python 中的 Hook 钩子函数

    1. 什么是Hook 经常会听到钩子函数(hook function)这个概念,最近在看目标检测开源框架mmdetection,里面也出现大量Hook的编程方式,那到底什么是hook?hook的作用是什么? what is hook ?钩子hook,顾名思义,可以理解是一个挂钩,作用是有需要的时候挂一个东西上去.具体的解释是:钩子函数是把我们自己实现的hook函数在某一时刻挂接到目标挂载点上. hook函数的作用 举个例子,hook的概念在windows桌面软件开发很常见,特别是各种事件触发的机

  • PHP中的插件机制原理和实例

    PHP项目中很多用到插件的地方,更尤其是基础程序写成之后很多功能由第三方完善开发的时候,更能用到插件机制,现在说一下插件的实现.特点是无论你是否激活,都不影响主程序的运行,即使是删除也不会影响. 从一个插件安装到运行过程的角度来说,主要是三个步骤: 1.插件安装(把插件信息收集进行采集和记忆的过程,比如放到数据库中或者XML中) 2.插件激活(打开插件,让监听插件的地方开始进行调用) 3.插件运行(插件功能的实现) 从一个插件的运行上来说主要以下几点: 1.插件的动态监听和加载(插件的信息获取)

  • Java开发之内部类对象的创建及hook机制分析

    本文实例讲述了Java内部类对象的创建及hook机制.分享给大家供大家参考,具体如下: Java中的内部类虽然在状态信息上与其外围类在状态信息是完全独立的(可直接通过内部类实例执行其功能),但是外围类对象却是内部类对象得以存在的基础. 内部类对象生成的时候,必须要保证其能够有外围类对象进行挂靠(hook),从而java提供了严格的内部类对象生成的语法. 一般惯用两种方法,生成内部类对象. Method1: 使用  外围类实例.new  内部类名称() 的标准方法. Example 1: publ

  • 关于PyTorch 自动求导机制详解

    自动求导机制 从后向中排除子图 每个变量都有两个标志:requires_grad和volatile.它们都允许从梯度计算中精细地排除子图,并可以提高效率. requires_grad 如果有一个单一的输入操作需要梯度,它的输出也需要梯度.相反,只有所有输入都不需要梯度,输出才不需要.如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行. >>> x = Variable(torch.randn(5, 5)) >>> y = Variable(torch.rand

  • 关于pytorch中网络loss传播和参数更新的理解

    相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇. TensorFlow: 228--->266 Keras: 42--->56 Pytorch: 87--->252 在使用pytorch中,自己有一些思考,如下: 1. loss计算和反向传播 import torch.nn as nn

随机推荐