pytorch打印网络结构的实例

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}

  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()

  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph

class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out

*****************************
def make_dot(): #复制上面的代码
*****************************

if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 1024,1024))
  y = net(x)
  g = make_dot(y)
  g.view() 

  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 对pytorch网络层结构的数组化详解

    最近再写openpose,它的网络结构是多阶段的网络,所以写网络的时候很想用列表的方式,但是直接使用列表不能将网络中相应的部分放入到cuda中去. 其实这个问题很简单的,使用moduleList就好了. 1 我先是定义了一个函数,用来根据超参数,建立一个基础网络结构 stage = [[3, 3, 3, 1, 1], [7, 7, 7, 7, 7, 1, 1]] branches_cfg = [[[128, 128, 128, 512, 38], [128, 128, 128, 512, 19]

  • pytorch打印网络结构的实例

    最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱:以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary().或者plot_model().pytorch没有这样的API,但是可以用代码来完成. (1)安装环境:graphviz conda install -n pytorch python-graphviz 或: sudo apt-get instal

  • Java编程实现从尾到头打印链表代码实例

    问题描述:输入一个链表的头结点,从尾巴到头反过来打印出每个结点的值. 首先定义链表结点 public class ListNode { int val; ListNode next = null; ListNode(int val){ this.val = val; } } 思路1:此题明显想到是利用栈的思想,后进先出,先遍历链表,依次将结点值进栈.最后在遍历栈出栈. public static Stack<Integer> printListReverse_Stack(ListNode li

  • javaScript 连接打印机,打印小票的实例

    如下所示: <%@ page contentType="text/html;charset=UTF-8"%> <%@ include file="/webpage/include/taglib.jsp"%> <!-- <!DOCTYPE html> --> <html> <head> <meta name="decorator" content="defaul

  • Python logging管理不同级别log打印和存储实例

    Python内置模块logging管理不同级别log打印和存储,非常方便,从此告别了使用print打桩记录,我们来看下logging的魅力吧 import logging logging.basicConfig(level = logging.DEBUG, format = '%(asctime)s %(filename)s[line:%(lineno)d]%(levelname)s %(message)s', datefmt = '%a, %d %b %Y %H:%M:%S', filenam

  • Django 使用logging打印日志的实例

    Django使用python自带的logging 作为日志打印工具.简单介绍下logging. logging 是线程安全的,其主要由4部分组成: Logger 用户使用的直接接口,将日志传递给Handler Handler 控制日志输出到哪里,console,file- 一个logger可以有多个Handler Filter 控制哪些日志可以从logger流向Handler Formatter 控制日志的格式 用户使用logging.getLogger([name])获取logger实例. 如

  • Java编程实现打印螺旋矩阵实例代码

    直接上代码吧. 昨晚腾讯在线测试遇到的题. 螺旋矩阵是指一个呈螺旋状的矩阵,它的数字由第一行开始到右边不断变大,向下变大,向左变大,向上变大,如此循环. import java.util.Scanner; public class mysnakematrix { private int n; // private int a[][]; // 声明一个矩阵 private int value = 1; // 矩阵里数字的值 public mysnakematrix(int i) { this.n

  • 关于pytorch多GPU训练实例与性能对比分析

    以下实验是我在百度公司实习的时候做的,记录下来留个小经验. 多GPU训练 cifar10_97.23 使用 run.sh 文件开始训练 cifar10_97.50 使用 run.4GPU.sh 开始训练 在集群中改变GPU调用个数修改 run.sh 文件 nohup srun --job-name=cf23 $pt --gres=gpu:2 -n1 bash cluster_run.sh $cmd 2>&1 1>>log.cf50_2GPU & 修改 –gres=gpu:

  • 对PyTorch torch.stack的实例讲解

    不是concat的意思 import torch a = torch.ones([1,2]) b = torch.ones([1,2]) torch.stack([a,b],1) (0 ,.,.) = 1 1 1 1 [torch.FloatTensor of size 1x2x2] 以上这篇对PyTorch torch.stack的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • pytorch 求网络模型参数实例

    用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量.下面分别介绍来两种方法求模型参数 一 .求得每一层的模型参数,然后自然的可以计算出总的参数. 1.先初始化一个网络模型model 比如我这里是 model=cliqueNet(里面是些初始化的参数) 2.调用model的Parameters类获取参数列表 一个典型的操作就是将参数列表传入优化器里.如下 optimizer = optim.Adam(model.parameters(), lr=opt.lr) 言归正传,继续回到参

  • pytorch: Parameter 的数据结构实例

    一般来说,pytorch 的Parameter是一个tensor,但是跟通常意义上的tensor有些不一样 1) 通常意义上的tensor 仅仅是数据 2) 而Parameter所对应的tensor 除了包含数据之外,还包含一个属性:requires_grad(=True/False) 在Parameter所对应的tensor中获取纯数据,可以通过以下操作: param_data = Parameter.data 测试代码: #-*-coding:utf-8-*- import torch im

随机推荐