Pytorch实现常用乘法算子TensorRT的示例代码

目录
  • 1.乘法运算总览
  • 2.乘法算子实现
    • 2.1矩阵乘算子实现
    • 2.2点乘算子实现

本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现。

pytorch 用于训练,TensorRT 用于推理是很多 AI 应用开发的标配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识。

1.乘法运算总览

先把 pytorch 中的一些常用的乘法运算进行一个总览:

  • torch.mm:用于两个矩阵 (不包括向量) 的乘法,如维度 (m, n) 的矩阵乘以维度 (n, p) 的矩阵;
  • torch.bmm:用于带 batch 的三维向量的乘法,如维度 (b, m, n) 的矩阵乘以维度 (b, n, p) 的矩阵;
  • torch.mul:用于同维度矩阵的逐像素点相乘,也即点乘,如维度 (m, n) 的矩阵点乘维度 (m, n) 的矩阵。该方法支持广播,也即支持矩阵和元素点乘;
  • torch.mv:用于矩阵和向量的乘法,矩阵在前,向量在后,如维度 (m, n) 的矩阵乘以维度为 (n) 的向量,输出维度为 (m);
  • torch.matmul:用于两个张量相乘,或矩阵与向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
  • @:作用相当于 torch.matmul;
  • *:作用相当于 torch.mul;

如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。

2.乘法算子实现

2.1矩阵乘算子实现

先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):

>>> import torch
>>> # torch.mm
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99, 88)
>>> c = torch.mm(a, b)
>>> c.shape
torch.size([66, 88])
>>>
>>> # torch.bmm
>>> a = torch.randn(3, 66, 99)
>>> b = torch.randn(3, 99, 77)
>>> c = torch.bmm(a, b)
>>> c.shape
torch.size([3, 66, 77])
>>>
>>> # torch.mv
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99)
>>> c = torch.mv(a, b)
>>> c.shape
torch.size([66])
>>>
>>> # torch.matmul
>>> a = torch.randn(32, 3, 66, 99)
>>> b = torch.randn(32, 3, 99, 55)
>>> c = torch.matmul(a, b)
>>> c.shape
torch.size([32, 3, 66, 55])
>>>
>>> # @
>>> d = a @ b
>>> d.shape
torch.size([32, 3, 66, 55])

来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply 方法覆盖,对应 torch.matmul,先来看该方法的定义:

//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
  ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
  return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}

可以看到这个方法有四个传参,对应两个张量和其 operation。来看这个算子在 TensorRT 中怎么添加:

// 构造张量 Tensor0
nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);
// 构造张量 Tensor1
nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);

// 添加矩阵乘法
nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);

// 获取输出
matmulOutput = Matmul_layer->getOputput(0);

2.2点乘算子实现

再来看看点乘的 pytorch 的实现 (以下实现在终端):

>>> import torch
>>> # torch.mul
>>> a = torch.randn(66, 99)
>>> b = torch.randn(66, 99)
>>> c = torch.mul(a, b)
>>> c.shape
torch.size([66, 99])
>>> d = 0.125
>>> e = torch.mul(a, d)
>>> e.shape
torch.size([66, 99])
>>> # *
>>> f = a * b
>>> f.shape
torch.size([66, 99])

来看 TensorRT 的实现,以上乘法都可使用 addScale 方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:

//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//!              This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//!              and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
  return mImpl->addScale(input, mode, shift, scale, power);
}

可以看到有三个模式:

  • kUNIFORM:weights 为一个值,对应张量乘一个元素;
  • kCHANNEL:weights 维度和输入张量通道的 c 维度对应,可以做一些以通道为基准的预处理;
  • kELEMENTWISE:weights 维度和输入张量的 c、h、w 对应,不考虑 batch,所以是输入的后三维;

再来看这个算子在 TensorRT 中怎么添加:

// 构造张量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);

// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;

// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };

// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行

// 添加张量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);

// 获取输出
scaleOutput = Scale_layer->getOputput(0);

有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。

到此这篇关于Pytorch实现常用乘法算子TensorRT的示例代码的文章就介绍到这了,更多相关Pytorch乘法算子TensorRT内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Python api构建tensorrt加速模型的步骤详解

    目录 一.创建TensorRT有以下几个步骤: 二.Python api和C++ api在实现网络加速有什么区别? 三.构建TensorRT加速模型 3.1 加载tensorRT 3.2 创建网络 3.3 ONNX构建engine 一.创建TensorRT有以下几个步骤: 1.用TensorRT中network模块定义网络模型 2.调用TensorRT构建器从网络创建优化的运行时引擎 3.采用序列化和反序列化操作以便在运行时快速重建 4.将数据喂入engine中进行推理 二.Python api

  • Pytorch通过保存为ONNX模型转TensorRT5的实现

    1 Pytorch以ONNX方式保存模型 def saveONNX(model, filepath): ''' 保存ONNX模型 :param model: 神经网络模型 :param filepath: 文件保存路径 ''' # 神经网络输入数据类型 dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda') torch.onnx.export(model, dummy_input, filepath,

  • PyTorch模型转TensorRT是怎么实现的?

    转换步骤概览 准备好模型定义文件(.py文件) 准备好训练完成的权重文件(.pth或.pth.tar) 安装onnx和onnxruntime 将训练好的模型转换为.onnx格式 安装tensorRT 环境参数 ubuntu-18.04 PyTorch-1.8.1 onnx-1.9.0 onnxruntime-1.7.2 cuda-11.1 cudnn-8.2.0 TensorRT-7.2.3.4 PyTorch转ONNX Step1:安装ONNX和ONNXRUNTIME 网上找到的安装方式是通过

  • Pytorch实现常用乘法算子TensorRT的示例代码

    目录 1.乘法运算总览 2.乘法算子实现 2.1矩阵乘算子实现 2.2点乘算子实现 本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现. pytorch 用于训练,TensorRT 用于推理是很多 AI 应用开发的标配.大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识. 1.乘法运算总览 先把 pytorch 中的一些常用的乘法运算进行一个总览: torch.mm:用于

  • Unity常用音频操作类示例代码

    下面通过代码给大家介绍Unity常用音频操作类,具体代码如下所示: using UnityEngine; using System.Collections; public class AudioPlay : MonoBehaviour { public static AudioPlay Instance; public AudioClip[] FuChuAudio; public AudioSource FCAudio; // public AudioSource BabyAudio; // U

  • Go语言实现常用排序算法的示例代码

    目录 冒泡排序 快速排序 选择排序 插入排序 排序算法是在生活中随处可见,也是算法基础,因为其实现代码较短,应用较常见.所以在面试中经常会问到排序算法及其相关的问题,可以说是每个程序员都必须得掌握的了.为了方便大家学习,花了一天的时间用Go语言实现一下常用的算法且整理了一下,如有需要可以参考. 冒泡排序 思路:从前往后对相邻的两个元素依次进行比较,让较大的数往下沉,较小的网上冒,即每当两个相邻的元素比较后发现他们的排序要求相反时,就将它们互换. 时间复杂度:O(N^2) 空间复杂度:O(1) f

  • nginx常用配置conf的示例代码详解

    nginx常用配置conf 代理静态文件 # 静态文件 server { # 压缩问价你配置 gzip on; gzip_min_length 1k; gzip_buffers 4 16k; gzip_http_version 1.1; gzip_comp_level 6; gzip_types text/plain text/css application/javascript application/json image/jpeg image/png image/gif; gzip_disa

  • pytorch 可视化feature map的示例代码

    之前做的一些项目中涉及到feature map 可视化的问题,一个层中feature map的数量往往就是当前层out_channels的值,我们可以通过以下代码可视化自己网络中某层的feature map,个人感觉可视化feature map对调参还是很有用的. 不多说了,直接看代码: import torch from torch.autograd import Variable import torch.nn as nn import pickle from sys import path

  • 超详细PyTorch实现手写数字识别器的示例代码

    前言 深度学习中有很多玩具数据,mnist就是其中一个,一个人能否入门深度学习往往就是以能否玩转mnist数据来判断的,在前面很多基础介绍后我们就可以来实现一个简单的手写数字识别的网络了 数据的处理 我们使用pytorch自带的包进行数据的预处理 import torch import torchvision import torchvision.transforms as transforms import numpy as np import matplotlib.pyplot as plt

  • Java常用工具类汇总 附示例代码

    一.FileUtils private static void fileUtilsTest() { try { //读取文件内容 String readFileToString = FileUtils.readFileToString(new File("D:\\guor\\data\\test20211022000000.txt")); System.out.println(readFileToString); //删除文件夹 FileUtils.deleteDirectory(ne

  • pytorch教程网络和损失函数的可视化代码示例

    目录 1.效果 2.环境 3.用到的代码 1.效果 2.环境 1.pytorch 2.visdom 3.python3.5 3.用到的代码 # coding:utf8 import torch from torch import nn, optim # nn 神经网络模块 optim优化函数模块 from torch.utils.data import DataLoader from torch.autograd import Variable from torchvision import t

  • PyTorch实现手写数字识别的示例代码

    目录 加载手写数字的数据 数据加载器(分批加载) 建立模型 模型训练 测试集抽取数据,查看预测结果 计算模型精度 自己手写数字进行预测 加载手写数字的数据 组成训练集和测试集,这里已经下载好了,所以download为False import torchvision # 是否支持gpu运算 # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # print(device) # print(torch.cud

  • Python常用工具类之adbtool示例代码

    1.adb常用命令 关闭adb服务:adb kill-server 启动adb服务  adb start-server 查询当前运行的所有设备  adb devices 可能在adb中存在多个虚拟设备运行 可以指定虚拟设备运行  -s 虚拟设备名称 重启设备 adb reboot  --指定虚拟设备   adb -s 设备名称 reboot 查看日志  adb logcat  清除日志 adb logcat -c 进入linux shell下  adb shell 其中常用的linux命令  c

随机推荐