Pytorch实现常用乘法算子TensorRT的示例代码
目录
- 1.乘法运算总览
- 2.乘法算子实现
- 2.1矩阵乘算子实现
- 2.2点乘算子实现
本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现。
pytorch 用于训练,TensorRT 用于推理是很多 AI 应用开发的标配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识。
1.乘法运算总览
先把 pytorch 中的一些常用的乘法运算进行一个总览:
- torch.mm:用于两个矩阵 (不包括向量) 的乘法,如维度 (m, n) 的矩阵乘以维度 (n, p) 的矩阵;
- torch.bmm:用于带 batch 的三维向量的乘法,如维度 (b, m, n) 的矩阵乘以维度 (b, n, p) 的矩阵;
- torch.mul:用于同维度矩阵的逐像素点相乘,也即点乘,如维度 (m, n) 的矩阵点乘维度 (m, n) 的矩阵。该方法支持广播,也即支持矩阵和元素点乘;
- torch.mv:用于矩阵和向量的乘法,矩阵在前,向量在后,如维度 (m, n) 的矩阵乘以维度为 (n) 的向量,输出维度为 (m);
- torch.matmul:用于两个张量相乘,或矩阵与向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
- @:作用相当于 torch.matmul;
- *:作用相当于 torch.mul;
如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。
2.乘法算子实现
2.1矩阵乘算子实现
先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):
>>> import torch >>> # torch.mm >>> a = torch.randn(66, 99) >>> b = torch.randn(99, 88) >>> c = torch.mm(a, b) >>> c.shape torch.size([66, 88]) >>> >>> # torch.bmm >>> a = torch.randn(3, 66, 99) >>> b = torch.randn(3, 99, 77) >>> c = torch.bmm(a, b) >>> c.shape torch.size([3, 66, 77]) >>> >>> # torch.mv >>> a = torch.randn(66, 99) >>> b = torch.randn(99) >>> c = torch.mv(a, b) >>> c.shape torch.size([66]) >>> >>> # torch.matmul >>> a = torch.randn(32, 3, 66, 99) >>> b = torch.randn(32, 3, 99, 55) >>> c = torch.matmul(a, b) >>> c.shape torch.size([32, 3, 66, 55]) >>> >>> # @ >>> d = a @ b >>> d.shape torch.size([32, 3, 66, 55])
来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply
方法覆盖,对应 torch.matmul,先来看该方法的定义:
//! //! \brief Add a MatrixMultiply layer to the network. //! //! \param input0 The first input tensor (commonly A). //! \param op0 The operation to apply to input0. //! \param input1 The second input tensor (commonly B). //! \param op1 The operation to apply to input1. //! //! \see IMatrixMultiplyLayer //! //! \warning Int32 tensors are not valid input tensors. //! //! \return The new matrix multiply layer, or nullptr if it could not be created. //! IMatrixMultiplyLayer* addMatrixMultiply( ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept { return mImpl->addMatrixMultiply(input0, op0, input1, op1); }
可以看到这个方法有四个传参,对应两个张量和其 operation
。来看这个算子在 TensorRT 中怎么添加:
// 构造张量 Tensor0 nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0); // 构造张量 Tensor1 nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1); // 添加矩阵乘法 nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type); // 获取输出 matmulOutput = Matmul_layer->getOputput(0);
2.2点乘算子实现
再来看看点乘的 pytorch 的实现 (以下实现在终端):
>>> import torch >>> # torch.mul >>> a = torch.randn(66, 99) >>> b = torch.randn(66, 99) >>> c = torch.mul(a, b) >>> c.shape torch.size([66, 99]) >>> d = 0.125 >>> e = torch.mul(a, d) >>> e.shape torch.size([66, 99]) >>> # * >>> f = a * b >>> f.shape torch.size([66, 99])
来看 TensorRT 的实现,以上乘法都可使用 addScale
方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:
//! //! \brief Add a Scale layer to the network. //! //! \param input The input tensor to the layer. //! This tensor is required to have a minimum of 3 dimensions in implicit batch mode //! and a minimum of 4 dimensions in explicit batch mode. //! \param mode The scaling mode. //! \param shift The shift value. //! \param scale The scale value. //! \param power The power value. //! //! If the weights are available, then the size of weights are dependent on the ScaleMode. //! For ::kUNIFORM, the number of weights equals 1. //! For ::kCHANNEL, the number of weights equals the channel dimension. //! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input. //! //! \see addScaleNd //! \see IScaleLayer //! \warning Int32 tensors are not valid input tensors. //! //! \return The new Scale layer, or nullptr if it could not be created. //! IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept { return mImpl->addScale(input, mode, shift, scale, power); }
可以看到有三个模式:
- kUNIFORM:weights 为一个值,对应张量乘一个元素;
- kCHANNEL:weights 维度和输入张量通道的 c 维度对应,可以做一些以通道为基准的预处理;
- kELEMENTWISE:weights 维度和输入张量的 c、h、w 对应,不考虑 batch,所以是输入的后三维;
再来看这个算子在 TensorRT 中怎么添加:
// 构造张量 input nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value); // scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISE scalemode = kUNIFORM; // 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量 nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume }; nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume }; nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume }; // !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行 // 添加张量乘法 nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower); // 获取输出 scaleOutput = Scale_layer->getOputput(0);
有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。
到此这篇关于Pytorch实现常用乘法算子TensorRT的示例代码的文章就介绍到这了,更多相关Pytorch乘法算子TensorRT内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!