pytorch hook 钩子函数的用法

钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)。

Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 featuregradient,从而诊断神经网络中可能出现的问题,分析网络有效性。

本文主要用 hook 函数输出网络执行过程中 forward 和 backward 的执行顺序,以此找到了bug所在。

用法如下:

# 设置hook func
def hook_func(name, module):
    def hook_function(module, inputs, outputs):
        # 请依据使用场景自定义函数
        print(name+' inputs', inputs)
        print(name+' outputs', outputs)
    return hook_function

# 注册正反向hook
for name, module in model.named_modules():
    module.register_forward_hook(hook_func('[forward]: '+name, module))
    module.register_backward_hook(hook_func('[backward]: '+name, module))

如一个简单的 MNIST 手写数字识别的模型结构如下:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

打印模型:

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

构建hook函数:

# 设置hook func
def hook_func(name, module):
    def hook_function(module, inputs, outputs):
        with open("log_model.txt", 'a+') as f:
            # 请依据使用场景自定义函数
            f.write(name + '   len(inputs): ' + str(len(inputs)) + '\n')
            f.write(name + '   len(outputs):  ' + str(len(outputs)) + '\n')
    return hook_function

# 注册正反向hook
for name, module in model.named_modules():
    module.register_forward_hook(hook_func('[forward]: '+name, module))
    module.register_backward_hook(hook_func('[backward]: '+name, module))

输出的前向和反向传播过程:

[forward]: conv1   len(inputs): 1
[forward]: conv1   len(outputs):  8
[forward]: conv2   len(inputs): 1
[forward]: conv2   len(outputs):  8
[forward]: dropout1   len(inputs): 1
[forward]: dropout1   len(outputs):  8
[forward]: fc1   len(inputs): 1
[forward]: fc1   len(outputs):  8
[forward]: dropout2   len(inputs): 1
[forward]: dropout2   len(outputs):  8
[forward]: fc2   len(inputs): 1
[forward]: fc2   len(outputs):  8
[forward]:    len(inputs): 1
[forward]:    len(outputs):  8
[backward]:    len(inputs): 2
[backward]:    len(outputs):  1
[backward]: fc2   len(inputs): 3
[backward]: fc2   len(outputs):  1
[backward]: dropout2   len(inputs): 1
[backward]: dropout2   len(outputs):  1
[backward]: fc1   len(inputs): 3
[backward]: fc1   len(outputs):  1
[backward]: dropout1   len(inputs): 1
[backward]: dropout1   len(outputs):  1
[backward]: conv2   len(inputs): 2
[backward]: conv2   len(outputs):  1
[backward]: conv1   len(inputs): 2
[backward]: conv1   len(outputs):  1

因为只要模型处于train状态,hook_func 就会执行,导致不断输出 [forward] 和 [backward],所以将输出内容建议写到文件中,而不是 print

到此这篇关于pytorch hook 钩子函数的用法的文章就介绍到这了,更多相关pytorch hook 钩子函数内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 5 分钟读懂Python 中的 Hook 钩子函数

    1. 什么是Hook 经常会听到钩子函数(hook function)这个概念,最近在看目标检测开源框架mmdetection,里面也出现大量Hook的编程方式,那到底什么是hook?hook的作用是什么? what is hook ?钩子hook,顾名思义,可以理解是一个挂钩,作用是有需要的时候挂一个东西上去.具体的解释是:钩子函数是把我们自己实现的hook函数在某一时刻挂接到目标挂载点上. hook函数的作用 举个例子,hook的概念在windows桌面软件开发很常见,特别是各种事件触发的机

  • python hook监听事件详解

    本文实例为大家分享了python hook监听事件的具体代码,供大家参考,具体内容如下 # -*- coding: utf-8 -*- # # by oldj http://oldj.net/ # import pythoncom import pyHook def onMouseEvent(event): # 监听鼠标事件 print "MessageName:",event.MessageName print "Message:", event.Message

  • python利用hook技术破解https的实例代码

    相对于http协议,http是的特点就是他的安全性,http协议的通信内容用普通的嗅探器可以捕捉到,但是https协议的内容嗅探到的是加密后的内容,对我们的利用价值不是很高,所以一些大的网站----涉及到"大米"的网站,采用的都是http是协议,嘿嘿,即便这样,还是有办法能看到他的用户名和密码的,嘿嘿,本文只是用于技术学习,只是和大家交流技术,希望不要用于做违法的事情,这个例子是在firefox浏览器下登录https协议的网站,我们预先打开程序,就来了个捕获用户名和密码: 下面是源代码

  • 详解Python开发中如何使用Hook技巧

    什么是Hook,就是在一个已有的方法上加入一些钩子,使得在该方法执行前或执行后另在做一些额外的处理,那么Hook技巧有什么作用以及我们为什么需要使用它呢,事实上如果一个项目在设计架构时考虑的足够充分,模块抽象的足够合理,设计之初为以后的扩展预留了足够的接口,那么我们完全可以不需要Hook技巧.但恰恰架构人员在项目设计之初往往没办法想的足够的深远,使得后续在扩展时深圳面临重构的痛苦,这时Hook技巧似乎可以为我们带来一记缓兵之计,通过对旧的架构进行加钩子来满足新的扩展需求. 下面我们就来看看如果进

  • pytorch hook 钩子函数的用法

    钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用.消息传递.事件传递来修改或扩展操作系统.应用程序或其他软件组件的行为的各种技术.处理被拦截的函数调用.事件.消息的代码,被称为钩子(hook). Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构,方便地获取.改变网络中间层变量的值和梯度.这个功能被广泛用于可视化神经网络中间层的 feature.gradient,从而诊断神经网络中可能出现的问题,分析网络

  • PyTorch中topk函数的用法详解

    听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index. 用法 torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor) input:一个tensor数据 k:指明是得到前k个数据以及其index dim: 指定在哪个维度上排序, 默认是最后一个维度 largest:如果为True,按照大到小排序: 如果为False,按照小到大排序

  • Pytorch上下采样函数--interpolate用法

    最近用到了上采样下采样操作,pytorch中使用interpolate可以很轻松的完成 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None): r""" 根据给定 size 或 scale_factor,上采样或下采样输入数据input. 当前支持 temporal, spatial 和 volumetric 输入数据的上采样,其shape 分别为:3-

  • vue created钩子函数与mounted钩子函数的用法区别

    1:在使用vue框架的过程中,我们经常需要给一些数据做一些初始化处理,这时候我们常用的就是在created与mounted选项中作出处理. 首先来看下官方解释,官方解释说created是在实例创建完成后呗立即调用. 在这一步,实例已完成以下配置:数据观测 (data observer),属性和方法的运算,watch/event 事件回调.然而,挂载阶段还没开始,$el 属性目前不可见. 这话的意思我觉得重点在于说挂架阶段还没开始,什么叫还没开始挂载,也就是说,模板还没有被渲染成html,也就是这

  • react新版本生命周期钩子函数及用法详解

    和旧的生命周期相比 准备废弃三个钩子,已经新增了两个钩子 React16 之后有三个生命周期被废弃(但并没有删除) componentWillMount( 组件将要挂载的钩子) componentWillReceiveProps(组件将要接收一个新的参数时的钩子) componentWillUpdate(组件将要更新的钩子) 新版本的生命周期新增的钩子 getDerivedStateFromProps 通过参数可以获取新的属性和状态 该函数是静态的 该函数的返回值会覆盖掉组件状态 getSnap

  • Vuerouter的beforeEach与afterEach钩子函数的区别

    在路由跳转的时候,我们需要一些权限判断或者其他操作.这个时候就需要使用路由的钩子函数. 定义:路由钩子主要是给使用者在路由发生变化时进行一些特殊的处理而定义的函数. 总体来讲vue里面提供了三大类钩子,两种函数 1.全局钩子 2.某个路由的钩子 3.组件内钩子 两种函数: 1.Vue.beforeEach(function(to,form,next){}) /*在跳转之前执行*/ 2.Vue.afterEach(function(to,form))/*在跳转之后判断*/ 全局钩子函数 顾名思义,

  • 聊聊vue生命周期钩子函数有哪些,分别什么时候触发

    目录 vue生命周期钩子函数 以下为详解版 生命周期mounted和activated使用.踩坑 activated mounted 踩坑 vue生命周期钩子函数 vue生命周期即为一个组件从出生到死亡的一个完整周期 主要包括以下4个阶段:创建,挂载,更新,销毁 创建前:beforeCreate, 创建后:created 挂载前:beforeMount, 挂载后:mounted 更新前:beforeUpdate, 更新后:updated 销毁前:beforeDestroy, 销毁后:destro

  • Flask框架钩子函数功能与用法分析

    本文实例讲述了Flask框架钩子函数功能与用法.分享给大家供大家参考,具体如下: 在Flask中钩子函数是使用特定的装饰器的函数.为什么叫做钩子函数呢,是因为钩子函数可以在正常执行的代码中,插入一段自己想要执行的代码,那么这种函数就叫做钩子函数. before_first_request:Flask项目第一次部署后会执行的钩子函数. before_request:请求已经到达了Flask,但是还没有进入到具体的视图函数之前调用.一般这个就是在函数之前,我们可以把一些后面需要用到的数据先处理好,方

  • vue-router的钩子函数用法实例分析

    本文实例讲述了vue-router的钩子函数用法.分享给大家供大家参考,具体如下: vue路由钩子大致可以分为三类: 1.全局钩子 主要包括beforeEach和aftrEach, beforeEach函数有三个参数: to:router即将进入的路由对象 from:当前导航即将离开的路由 next:Function,进行管道中的一个钩子,如果执行完了,则导航的状态就是 confirmed (确认的):否则为false,终止导航. afterEach函数不用传next()函数 这类钩子主要作用于

随机推荐