# TensorRT usage in Python script

There are two ways to use a TensorRT-optimized model:

* deploy it on a Triton inference server
* use it directly in Python

This document covers the second option.

## High-level explanations

* call `load_engine()` to parse an existing TensorRT engine or `build_engine()` to convert an ONNX file (a sketch for the existing-engine case follows this list)
* set up a CUDA `stream` (for asynchronous calls), a TensorRT `runtime` and a `context`
* activate your optimization `profile`(s)
* call `infer_tensorrt()`
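
The sections below walk through the `build_engine()` path. If you already have a serialized engine on disk, you can deserialize it instead; here is a minimal sketch using the plain TensorRT API (`model.plan` is an illustrative file name):

```python
import tensorrt as trt

# create a logger and a runtime, then deserialize the engine bytes from disk
trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)
with open("model.plan", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
```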

## Build engine

We assume that you have already exported your model to an ONNX file.  
Now we need to convert it to a TensorRT engine:

```python
import tensorrt as trt
from tensorrt.tensorrt import Logger, Runtime

from transformer_deploy.backends.trt_utils import build_engine

trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
profile_index = 0
max_seq_len = 256
batch_size = 32

engine = build_engine(
    runtime=runtime,
    onnx_file_path="model_qat.onnx",
    logger=trt_logger,
    # dynamic shape profile, expressed as (batch size, sequence length)
    min_shape=(1, max_seq_len),
    optimal_shape=(batch_size, max_seq_len),
    max_shape=(batch_size, max_seq_len),
    # ~10 GB of workspace memory for the builder
    workspace_size=10000 * 1024 * 1024,
    # enable FP16 and INT8 kernels (the file name suggests a quantization-aware trained model)
    fp16=True,
    int8=True,
)
```
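
To avoid rebuilding the engine on every run, you can serialize it to disk with the standard TensorRT API and reload it later, as sketched in the high-level section above (`model.plan` is an illustrative file name):

```python
# serialize the engine so it can be deserialized later without a rebuild
with open("model.plan", "wb") as f:
    f.write(engine.serialize())
```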

## Prepare inference

Now that the engine is ready, we can prepare the inference:

```python
import torch
from tensorrt.tensorrt import IExecutionContext

from transformer_deploy.backends.trt_utils import get_binding_idxs

# create an execution context from the engine
context: IExecutionContext = engine.create_execution_context()
# activate the optimization profile on the current CUDA stream (asynchronous call)
context.set_optimization_profile_async(
    profile_index=profile_index, stream_handle=torch.cuda.current_stream().cuda_stream
)
# retrieve the input and output binding indexes for the selected profile
input_binding_idxs, output_binding_idxs = get_binding_idxs(engine, profile_index)  # type: List[int], List[int]
```
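
Because the host inputs must match the names, dtypes and shapes that the engine expects, it can help to inspect the bindings before running inference. A small sketch using the binding-based TensorRT API (the same API family `get_binding_idxs` relies on; it is deprecated in recent TensorRT releases):

```python
# list the engine bindings: name, direction, dtype and (possibly dynamic) shape
for binding_index in range(engine.num_bindings):
    name = engine.get_binding_name(binding_index)
    direction = "input" if engine.binding_is_input(binding_index) else "output"
    dtype = engine.get_binding_dtype(binding_index)
    shape = engine.get_binding_shape(binding_index)
    print(f"{direction}: {name}, dtype={dtype}, shape={shape}")
```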

## Inference

```python
from transformer_deploy.backends.trt_utils import infer_tensorrt

# host (CPU) inputs expected by the model (left unspecified here; see the sketch below)
input_np = ...

# run the model on GPU through the TensorRT execution context
tensorrt_output = infer_tensorrt(
    context=context,
    host_inputs=input_np,
    input_binding_idxs=input_binding_idxs,
    output_binding_idxs=output_binding_idxs,
)
print(tensorrt_output)
```
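
In the snippet above, `input_np` is left unspecified. For a transformer model it is typically a dict of numpy arrays produced by a tokenizer, keyed by the input names of the exported ONNX graph. A minimal sketch, assuming a Hugging Face tokenizer (`bert-base-uncased` and the input names are illustrative) with integer inputs cast to `int32`, which TensorRT engines commonly expect:

```python
import numpy as np
from transformers import AutoTokenizer

# use the checkpoint that matches the model you exported to ONNX
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer(
    ["This is a test sentence."] * batch_size,
    padding="max_length",
    max_length=max_seq_len,
    truncation=True,
    return_tensors="np",
)
# keep only the inputs present in the ONNX graph, cast to int32
input_np = {
    name: np.asarray(encoding[name], dtype=np.int32)
    for name in ("input_ids", "attention_mask")
}
```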

... and you are done! 🎉

!!! tip

    To go deeper, check these parts of the API documentation:

    * `Convert`
    * `Backends/Trt utils`

    ... and if you are looking for inspiration, check [onnx-tensorrt](https://github.com/onnx/onnx-tensorrt)

--8<-- "resources/abbreviations.md"