Super雨

其实没有那么多观众,大胆地自由地生活!

TensorRT插件基础教程基于TensorRT8[1]

TensorRT插件零基础教程(1)(基于TensorRT8)

一、写在前面

引言

TensorRT 是当前少有的基于 GPU 的高速推理框架,是网络部署到生产环境中的不二之选。当下流行的使用方式是将网络模型从训练框架下导出为 ONNX 格式的通用神经网络交换格式,然后经由 TensorRT 官方提供的 OnnxParser 解析为 TensorRT 认识的网络定义,再进行网络 engine 的构建。然而,在所难免的问题是,由于人工智能领域常常会出现新的算子,或者说用户本身想在前传过程中实现自定义的行为,此时只能借助插件,于是,插件的编写成了模型部署阶段的巨大绊脚石。

可以说,如果不会编写插件,就不是真正的会用 TensorRT 这个框架,插件才是使用 TensorRT 的灵魂所在。但是 TensorRT 的使用本身就有一定的门槛,编写插件更是比较困难,原因在于:

  1. 可能刚接触神经网络的人能够借助大量的网络资源在很短时间就搭建出网络模型,但是对网络本身对输入数据的操作过程不是很理解,编写插件要求用户清楚的知道要对数据做什么样的操作;
  2. 虽然 TensorRT 提供了较为直观的 Python API,但是在插件编写时绕不开 C++ 编程,而且要对面向对象要有一定的理解;
  3. 某些加速操作要用到 GPU 运算,插件要嵌入 CUDA 代码,以操作 GPU;
  4. TensorRT 官方的文档并未详细说明插件怎么编写,用户需要参考官方提供的示例编写自己的插件。

本系列教程会从插件编写过程中用到的重要的数据结构开始,以 GroupNoamalization 这个插件为例,详细说明插件的编写流程。

读完本篇,你将获得

对插件的构建过程有一个大致的认识。

二、步入正题

1、一些重要的数据结构

描述插件 input 或者 output 张量的描述符,官方称为 fields:

1
2
3
4
5
6
7
8
9
10
11
struct PluginTensorDesc  
{
//! Dimensions,维度,包含维数与每个维度的元素个数
Dims dims;
//! \warning DataType:kBOOL not supported.数据类型,float、int32等
DataType type;
//! Tensor format.
TensorFormat format;
//! Scale for INT8 data type.
float scale;
};

2、插件的大致实现流程

TensorRT 官方定义了一些关于插件的基类,这些基类中定义了一些解析插件、创建插件所必须的函数接口,只要用户实现了这些函数,那么在 TensorRT 的机制下就能顺利创建插件,并在engine中使用插件。两个重要的类是:nvinfer1::IPluginV2DynamicExtIPluginCreator,前者有多个版本,在8.0版本我们使用IPluginV2DynamicExt即可。这两个插件中有很多函数,但是只有很少的一部分是需要我们细心实现的,所以刚看到这么一堆庞大的代码,不要产生厌烦心理。

继承了nvinfer1::IPluginV2DynamicExtGroupNormalizationPlugin类需要认真实现上述着色的函数,绿色的是构造函数,其他函数照猫画虎很好写,继承了IPluginCreatorGroupNormalizationPluginCreator也仅仅有3个函数要认真实现:

3、Tips

由于TensorRT并不开源,所以其很多实现的细节我们都看不到,对一些概念也难以琢磨到其本质的实现,但是根据其暴露的部分,我们可以猜到一些端倪,带着这些直觉来编写插件,可能会觉得更自然。

  1. 插件要实现具体的运算,所以要定义你具体的运算过程,这个运算过程应当是一个函数,可以被enqueue调用;
  2. 插件要接受一些输入或者参数,这些参数要保存,正是保存在builder生成的engine文件里,写入engine是序列化(serialize)读出参数是反序列化(deserilize),自然的,你需要给TensorRT提供好这些函数;
  3. 插件可能在网络里用到多次,提供给TensorRT一个clone函数更合适;

由此看来,实现插件其实并不是很复杂,他要什么用户就提供什么即可。

三、写在最后

这篇是一个前奏,下一篇说明两个派生类的内部的函数具体是做什么的以及怎么实现。