pytorch __init__、forward与__call__的用法小结

1.介绍

当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?

1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的

2)forward是表示一个前向传播,构建网络层的先后运算步骤

3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?

当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。

在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

2.代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))

 def forward(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x

class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))

 def __call__(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x

补充:torch/nn目录结构以及__init__.py

torch/nn目录结构以及init.py

torch/nn目录结构

__init__.py:

from .modules import *
#nn.modules  导入modules目录下内容 定义容器modules
from .parameter import Parameter
#nn.Parameter 导入parameter.py  定义parameter
from .parallel import DataParallel
#导入parallel目录下data_parallel.py中的DataParallel类
from . import init
#nn.init   导入init.py   参数初始化
from . import utils
#nn.utils  导入utils目录下内容 官网api下nn.utils下api

对于backends, functional.py, _functions 需要在代码前重新Import

例如我们常用的

import torch.nn.functional as F 就是导入了functional.py

backends和_functions是functional.py实现各种函数时所用到的。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。

(0)

相关推荐

  • python中的__init__ 、__new__、__call__小结

    1.__new__(cls, *args, **kwargs)  创建对象时调用,返回当前对象的一个实例;注意:这里的第一个参数是cls即class本身2.__init__(self, *args, **kwargs) 创建完对象后调用,对当前对象的实例的一些初始化,无返回值,即在调用__new__之后,根据返回的实例初始化:注意,这里的第一个参数是self即对象本身[注意和new的区别]3.__call__(self,  *args, **kwargs) 如果类实现了这个方法,相当于把这个类型

  • 浅谈python中的__init__、__new__和__call__方法

    前言 本文主要给大家介绍关于python中__init__.__new__和__call__方法的相关内容,分享出来供大家参考学习,下面话不多说,来一起看看详细的介绍: 任何事物都有一个从创建,被使用,再到消亡的过程,在程序语言面向对象编程模型中,对象也有相似的命运:创建.初始化.使用.垃圾回收,不同的阶段由不同的方法(角色)负责执行. 定义一个类时,大家用得最多的就是 __init__ 方法,而 __new__ 和 __call__ 使用得比较少,这篇文章试图帮助大家把这3个方法的正确使用方式

  • 基于tensorflow __init__、build 和call的使用小结

    1.介绍 在使用tf构建网络框架的时候,经常会遇到__init__.build 和call这三个互相搭配着使用,那么它们的区别主要在哪里呢? 1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面 2)call可以把类型的对象当做函数来使用,这个对象可以是在__init__里面也可以是在build里面 3)build一般是和call搭配使用,这个时候,它的功能和__init__很相似,当build中存放本层需要初始化的变量,当call被第一次调用的时候,会先

  • 详解Python中的__new__、__init__、__call__三个特殊方法

    __new__: 对象的创建,是一个静态方法,第一个参数是cls.(想想也是,不可能是self,对象还没创建,哪来的self) __init__ : 对象的初始化, 是一个实例方法,第一个参数是self. __call__ : 对象可call,注意不是类,是对象. 先有创建,才有初始化.即先__new__,而后__init__. 上面说的不好理解,看例子. 1.对于__new__ class Bar(object): pass class Foo(object): def __new__(cls

  • pytorch __init__、forward与__call__的用法小结

    1.介绍 当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__.build 和call小结)类似的情况,即经常会遇到__init__.forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢? 1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的 2)forward是表示一个前向传播,构建网络层的先后运算步骤 3)__call__的功能其实和fo

  • Pytorch 卷积中的 Input Shape用法

    先看Pytorch中的卷积 class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 二维卷积层, 输入的尺度是(N, C_in,H,W),输出尺度(N,C_out,H_out,W_out)的计算方式 这里比较奇怪的是这个卷积层居然没有定义input shape,输入尺寸明明是:(N, C_in, H,W),但是定义中却只需

  • C++语言中std::array的用法小结(神器用法)

    摘要:在这篇文章里,将从各个角度介绍下std::array的用法,希望能带来一些启发. td::array是在C++11标准中增加的STL容器,它的设计目的是提供与原生数组类似的功能与性能.也正因此,使得std::array有很多与其他容器不同的特殊之处,比如:std::array的元素是直接存放在实例内部,而不是在堆上分配空间:std::array的大小必须在编译期确定:std::array的构造函数.析构函数和赋值操作符都是编译器隐式声明的--这让很多用惯了std::vector这类容器的程

  • Java中String.split()用法小结

    在java.lang包中有String.split()方法,返回是一个数组 我在应用中用到一些,给大家总结一下,仅供大家参考: 1.如果用"."作为分隔的话,必须是如下写法,String.split("\\."),这样才能正确的分隔开,不能用String.split("."); 2.如果用"|"作为分隔的话,必须是如下写法,String.split("\\|"),这样才能正确的分隔开,不能用String.s

  • mybatis 中 foreach collection的用法小结(三种)

    foreach的主要用在构建in条件中,它可以在SQL语句中进行迭代一个集合. foreach元素的属性主要有 item,index,collection,open,separator,close. item表示集合中每一个元素进行迭代时的别名,     index指 定一个名字,用于表示在迭代过程中,每次迭代到的位置,     open表示该语句以什么开始,     separator表示在每次进行迭代之间以什么符号作为分隔 符,     close表示以什么结束. 在使用foreach的时候

  • JS产生随机数的用法小结

    代码如下所述: <script> function GetRandomNum(Min,Max) { var Range = Max - Min; var Rand = Math.random(); return(Min + Math.round(Rand * Range)); } var num = GetRandomNum(1,10); alert(num); </script> var chars = ['0','1','2','3','4','5','6','7','8','

  • 详解PHP中cookie和session的区别及cookie和session用法小结

    具体来说 cookie 是保存在"客户端"的,而session是保存在"服务端"的 cookie 是通过扩展http协议实现的 cookie 主要包括 :名字,值,过期时间,路径和域: 如果cookie不设置生命周期,则以浏览器关闭而关闭,这种cookie一般存储在内存而不是硬盘上.若设置了生命周期则相反,不随浏览器的关闭而消失,这些cookie仍然有效直到超过设定的过 期 时间. session 一种类似散列表的形式保存信息, 当程序需要为某个客户端的请求创建一个

  • MySql数据库中Select用法小结

    一.条件筛选 1.数字筛选:sql = "Select * from [sheet1$] Where 销售单价 > 100" 2.字符条件:sql = "Select * from [sheet1$] Where 物品名称 ='挡泥板'" 3.日期条件:sql = "Select * from [sheet1$] Where 物品名称 ='挡泥板'" 4.区间条件:sql = "Select * from [sheet1$] Wh

  • javaScript产生随机数的用法小结

    var chars = ['0','1','2','3','4','5','6','7','8','9','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']; function generateMixed(n) { var res = ""; for(var i = 0; i < n ; i ++) { var id = M

  • 基于pytorch 预训练的词向量用法详解

    如何在pytorch中使用word2vec训练好的词向量 torch.nn.Embedding() 这个方法是在pytorch中将词向量和词对应起来的一个方法. 一般情况下,如果我们直接使用下面的这种: self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeding_dim) num_embeddings=vocab_size 表示词汇量的大小 embedding_dim=embeding

随机推荐