Python torch.onnx.export用法详细介绍

目录
  • 函数原型
  • 参数介绍
    • mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)
    • args (tuple or torch.Tensor)
    • f
    • export_params (bool, default True)
    • verbose (bool, default False)
    • training (enum, default TrainingMode.EVAL)
    • input_names (list of str, default empty list)
    • output_names (list of str, default empty list)
    • operator_export_type (enum, default None)
    • opset_version (int, default 9)
    • do_constant_folding (bool, default False)
    • example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)
    • dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)
    • keep_initializers_as_inputs (bool, default None)
    • custom_opsets (dict<str, int>, default empty dict)
  • Torch.onnx.export执行流程:
  • 总结

函数原型

参数介绍

mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)

需要转换的模型,支持的模型类型有:torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction

args (tuple or torch.Tensor)

args可以被设置成三种形式

1.一个tuple

args = (x, y, z)

这个tuple应该与模型的输入相对应,任何非Tensor的输入都会被硬编码入onnx模型,所有Tensor类型的参数会被当做onnx模型的输入。

2.一个Tensor

args = torch.Tensor([1, 2, 3])

一般这种情况下模型只有一个输入

3.一个带有字典的tuple

args = (x,
        {'y': input_y,
         'z': input_z})

这种情况下,所有字典之前的参数会被当做“非关键字”参数传入网络,字典种的键值对会被当做关键字参数传入网络。如果网络中的关键字参数未出现在此字典中,将会使用默认值,如果没有设定默认值,则会被指定为None。

NOTE:

一个特殊情况,当网络本身最后一个参数为字典时,直接在tuple最后写一个字典则会被误认为关键字传参。所以,可以通过在tuple最后添加一个空字典来解决。

#错误写法:

torch.onnx.export(
    model,
    (x,
     # WRONG: will be interpreted as named arguments
     {y: z}),
    "test.onnx.pb")

# 纠正

torch.onnx.export(
    model,
    (x,
     {y: z},
     {}),
    "test.onnx.pb")

f

一个文件类对象或一个路径字符串,二进制的protocol buffer将被写入此文件

export_params (bool, default True)

如果为True则导出模型的参数。如果想导出一个未训练的模型,则设为False

verbose (bool, default False)

如果为True,则打印一些转换日志,并且onnx模型中会包含doc_string信息。

training (enum, default TrainingMode.EVAL)

枚举类型包括:

TrainingMode.EVAL - 以推理模式导出模型。

TrainingMode.PRESERVE - 如果model.training为False,则以推理模式导出;否则以训练模式导出。

TrainingMode.TRAINING - 以训练模式导出,此模式将禁止一些影响训练的优化操作。

input_names (list of str, default empty list)

按顺序分配给onnx图的输入节点的名称列表。

output_names (list of str, default empty list)

按顺序分配给onnx图的输出节点的名称列表。

operator_export_type (enum, default None)

默认为OperatorExportTypes.ONNX, 如果Pytorch built with DPYTORCH_ONNX_CAFFE2_BUNDLE,则默认为OperatorExportTypes.ONNX_ATEN_FALLBACK。

枚举类型包括:

OperatorExportTypes.ONNX - 将所有操作导出为ONNX操作。

OperatorExportTypes.ONNX_FALLTHROUGH - 试图将所有操作导出为ONNX操作,但碰到无法转换的操作(如onnx未实现的操作),则将操作导出为“自定义操作”,为了使导出的模型可用,运行时必须支持这些自定义操作。支持自定义操作方法见链接

OperatorExportTypes.ONNX_ATEN - 所有ATen操作导出为ATen操作,ATen是Pytorch的内建tensor库,所以这将使得模型直接使用Pytorch实现。(此方法转换的模型只能被Caffe2直接使用)

OperatorExportTypes.ONNX_ATEN_FALLBACK - 试图将所有的ATen操作也转换为ONNX操作,如果无法转换则转换为ATen操作(此方法转换的模型只能被Caffe2直接使用)。例如:

# 转换前:
graph(%0 : Float):
  %3 : int = prim::Constant[value=0]()
  # conversion unsupported
  %4 : Float = aten::triu(%0, %3)
  # conversion supported
  %5 : Float = aten::mul(%4, %0)
  return (%5)

# 转换后:
graph(%0 : Float):
  %1 : Long() = onnx::Constant[value={0}]()
  # not converted
  %2 : Float = aten::ATen[operator="triu"](%0, %1)
  # converted
  %3 : Float = onnx::Mul(%2, %0)
  return (%3)

opset_version (int, default 9)

默认是9。值必须等于_onnx_main_opset或在_onnx_stable_opsets之内。具体可在torch/onnx/symbolic_helper.py中找到。例如:

_default_onnx_opset_version = 9

_onnx_main_opset = 13

_onnx_stable_opsets = [7, 8, 9, 10, 11, 12]

_export_onnx_opset_version = _default_onnx_opset_version

do_constant_folding (bool, default False)

是否使用“常量折叠”优化。常量折叠将使用一些算好的常量来优化一些输入全为常量的节点。

example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)

当需输入模型为ScriptModule 或 ScriptFunction时必须提供。此参数用于确定输出的类型和形状,而不跟踪(tracing )模型的执行。

dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)

通过以下规则设置动态的维度:

KEY(str) - 必须是input_names或output_names指定的名称,用来指定哪个变量需要使用到动态尺寸。

VALUE(dict or list) - 如果是一个dict,dict中的key是变量的某个维度,dict中的value是我们给这个维度取的名称。如果是一个list,则list中的元素都表示此变量的某个维度。

具体可参考如下示例:

class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)

# 以动态尺寸模式导出模型

torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                  input_names=["x"], output_names=["sum"],
                  dynamic_axes={
                      # dict value: manually named axes
                      "x": {0: "my_custom_axis_name"},
                      # list value: automatic names
                      "sum": [0],
                  })

### 导出后的节点信息

##input

input {
  name: "x"
  ...
      shape {
        dim {
          dim_param: "my_custom_axis_name"  # axis 0
        }
        dim {
          dim_value: 2  # axis 1
...

##output
output {
  name: "sum"
  ...
      shape {
        dim {
          dim_param: "sum_dynamic_axes_1"  # axis 0
...
 

keep_initializers_as_inputs (bool, default None)

NONE

custom_opsets (dict<str, int>, default empty dict)

NONE

Torch.onnx.export执行流程:

1、如果输入到torch.onnx.export的模型是nn.Module类型,则默认会将模型使用torch.jit.trace转换为ScriptModule

2、使用args参数和torch.jit.trace将模型转换为ScriptModule,torch.jit.trace不能处理模型中的循环和if语句

3、如果模型中存在循环或者if语句,在执行torch.onnx.export之前先使用torch.jit.script将nn.Module转换为ScriptModule

4、模型转换成onnx之后,预测结果与之前会有稍微的差别,这些差别往往不会改变模型的预测结果,比如预测的概率在小数点之后五六位有差别。

总结

到此这篇关于Python torch.onnx.export用法详细介绍的文章就介绍到这了,更多相关Python torch.onnx.export用法内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Python torch.onnx.export用法详细介绍

    目录 函数原型 参数介绍 mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction) args (tuple or torch.Tensor) f export_params (bool, default True) verbose (bool, default False) training (enum, default TrainingMode.EVAL) input_names (list of st

  • Python进阶之高级用法详细总结

    一.Lambda表达式 Lambda表达式又被称之为匿名函数 格式 lambda 参数列表:函数体 def add(x,y): return x+y print(add(3,4)) #上面的函数可以写成Lambda函数 add_lambda=lambda x,y:x+y add_lambda(3,4) 二.map函数 函数就是有输入和输出,map的输入和输出对应关系如下图所示: 就是要把一个可迭代的对象按某个规则映射到新的对象上. 因此map函数要有两个参数,一个是映射规则,一个是可迭代对象.

  • python+mongodb数据抓取详细介绍

    分享点干货!!! Python数据抓取分析 编程模块:requests,lxml,pymongo,time,BeautifulSoup 首先获取所有产品的分类网址: def step(): try: headers = { ..... } r = requests.get(url,headers,timeout=30) html = r.content soup = BeautifulSoup(html,"lxml") url = soup.find_all(正则表达式) for i

  • Python 通过pip安装Django详细介绍

    Python 通过pip安装Django详细介绍 经过前面的 Python 包管理工具的学习,接下来我们就要基于前面的知识,来配置 Django 的开发与运行环境. 首先是安装 Django(通过pip安装): pip install Django 输出的结果在我这里是这样的: Downloading/unpacking Django Downloading Django-1.5.2.tar.gz (8.0MB): 8.0MB downloaded Running setup.py egg_in

  • nginx命令参数用法详细介绍

    nginx命令参数用法详细介绍 nginx命令:启动nginx 在Windows上安装好nginx后,我们需要启动nginx服务,启动nginx服务的命令行操作主要有两种方式,即 C:/nginx-0.8.53>nginx.exe 或者 C:/nginx-0.8.53>start nginx 启动nginx命令说明:需要注意,由于nginx默认端口也是80端口,如果此时你的机器上开启了Apache或者IIS服务,切忌在启动nginx之前务必关闭IIS或Apache服务,否则nginx启动命令不

  • python实现微信接口(itchat)详细介绍

    前言 itchat是一个开源的微信个人号接口,使用python调用微信从未如此简单.使用不到三十行的代码,你就可以完成一个能够处理所有信息的微信机器人.当然,该api的使用远不止一个机器人,更多的功能等着你来发现,比如这些.该接口与公众号接口itchatmp共享类似的操作方式,学习一次掌握两个工具.如今微信已经成为了个人社交的很大一部分,希望这个项目能够帮助你扩展你的个人的微信号.方便自己的生活. 安装 sudo pip install itchat 登录 itchat.auto_login()

  • PHP中error_reporting函数用法详细介绍

    PHP中error_reporting函数用法详细介绍 PHP中对错误的处理会用到error_reporting函数,看到最多的是error_reporting(E_ALL ^ E_NOTICE),这个是什么意思呢?下面我们具体分析error_reporting函数. 定义用法 error_reporting() 设置 PHP 的报错级别并返回当前级别. 语法 error_reporting(report_level) 如果参数 report_level 未指定,当前报错级别将被返回.下面几项是

  • Java多线程的用法详细介绍

    Java多线程的用法详细介绍 最全面的Java多线程用法解析,如果你对Java的多线程机制并没有深入的研究,那么本文可以帮助你更透彻地理解Java多线程的原理以及使用方法. 1.创建线程 在Java中创建线程有两种方法:使用Thread类和使用Runnable接口.在使用Runnable接口时需要建立一个Thread实例.因此,无论是通过Thread类还是Runnable接口建立线程,都必须建立Thread类或它的子类的实例.Thread构造函数: public Thread( ); publi

  • Spring中@Transactional用法详细介绍

    Spring中@Transactional用法详细介绍 引言: 在spring中@Transactional提供一种控制事务管理的快捷手段,但是很多人都只是@Transactional简单使用,并未深入了解,其各个配置项的使用方法,本文将深入讲解各个配置项的使用. 1.  @Transactional的定义 Spring中的@Transactional基于动态代理的机制,提供了一种透明的事务管理机制,方便快捷解决在开发中碰到的问题.在现实中,实际的问题往往比我们预期的要复杂很多,这就要求对@Tr

  • Python各种类型装饰器详细介绍

    目录 装饰器说明 装饰器分类 最简单的装饰器 用于修改对象的装饰器 用于模拟对象的装饰器--函数装饰器 用于模拟对象的装饰器--类方法装饰器 用于模拟对象的装饰器--类装饰器 特殊应用的装饰器 类实现的装饰器 装饰带参数/返回值的对象 装饰器带参数 装饰器应用 装饰器说明 Python中的装饰器是一种可以装饰其它对象的工具.该工具本质上是一个可调用的对象(callable),所以装饰器一般可以由函数.类来实现.装饰器本身需要接受一个被装饰的对象作为参数,该参数通常为函数.方法.类等对象.装饰器需

随机推荐