# Custom Operations

Plugins allow you to extend TensorRT with custom operations.

- The **quickly deployable plugin** (QDP) framework is the easiest way to write plugins.


## Implementing The Plugin

In this guide, we'll implement a plugin that increments a tensor by 1.

:::{seealso}
[TensorRT's guide on QDPs](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/pluginGuide.html)
includes more details on implementing plugins.
:::

We must:

1. **Register the interface** for the plugin.
2. Implement the **plugin kernel**.
3. Generate [**PTX**](https://docs.nvidia.com/cuda/parallel-thread-execution/).


### Registering The Plugin Interface

[`trtp.register`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/tensorrt.plugin/trt_plugin_register.html#tensorrt.plugin.register)
decorates a function that defines the plugin interface:

```py
import tensorrt.plugin as trtp

# Plugin IDs are of the form: "<namespace>::<name>" and
# uniquely identify a plugin.
INCREMENT_PLUGIN_ID = "example::increment"

@trtp.register(INCREMENT_PLUGIN_ID)
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
    """
    Defines the plugin interface - inputs, outputs and attributes.

    Args:
        inp0: Input tensor descriptor
        block_size: Block size for the Triton kernel

    Returns:
        Output tensor descriptor with same shape/dtype as input
    """
    return inp0.like()
```
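
`inp0.like()` declares a single output with the same shape and data type as the input. When the output shape differs from the input's, it can instead be built from symbolic shape expressions. Below is a minimal sketch, assuming `trtp.from_shape_expr` as described in the QDP guide; the `example::tile_rows` plugin is hypothetical and not used in the rest of this guide:

```py
import tensorrt.plugin as trtp

@trtp.register("example::tile_rows")
def tile_rows_plugin_desc(inp0: trtp.TensorDesc) -> trtp.TensorDesc:
    # Hypothetical plugin whose output has twice as many rows as its input.
    # The shape is expressed symbolically, so it also works with dynamic shapes.
    return trtp.from_shape_expr(
        (inp0.shape_expr[0] * 2, inp0.shape_expr[1]), dtype=inp0.dtype
    )
```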

### Implementing The Kernel

For this example, we use [OpenAI's Triton language](https://triton-lang.org/main/index.html)
to implement the kernel:

```py
import triton
import triton.language as tl

@triton.jit # doc: ignore-line
def increment(x_ptr, num_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-sized chunk of elements.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask off out-of-bounds accesses in the final chunk.
    mask = offsets < num_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)
```

<!-- Tripy: DOC: OMIT Start -->
<!-- Hack to make source code inspection work - the decorator tries to inspect the source
     code before we have injected it, so we need to invoke it *after* the function definition -->
```py
# doc: no-print-locals
increment.__globals__.update({"tl": tl})  # This is required to make `tl` available during `triton.compile`.
increment = triton.jit(increment)
```
<!-- Tripy: DOC: OMIT End -->

:::{note}
Kernels can be written in many other ways (e.g., CUDA, CUTLASS, or Numba), as long as we can emit PTX.
:::
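
Since the kernel is plain Triton, it can be sanity-checked outside of TensorRT before wiring it into the plugin. A quick sketch, assuming PyTorch and a CUDA-capable GPU are available (neither is required by the plugin itself):

```py
import torch
import triton

x = torch.arange(8, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
# Launch one program instance per BLOCK_SIZE-sized chunk of the input:
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
increment[grid](x, x.numel(), y, BLOCK_SIZE=256)
assert torch.equal(y, x + 1)
```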


### Retrieving PTX

[`trtp.aot_impl`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/tensorrt.plugin/trt_plugin_aot_impl/index.html#tensorrt.plugin.aot_impl)
decorates a function that retrieves PTX, launch parameters, and any extra arguments:

```py
from typing import Tuple, Union
import tensorrt.plugin as trtp

@trtp.aot_impl(INCREMENT_PLUGIN_ID)
def increment_aot_impl(
    inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    src = triton.compiler.ASTSource(
        fn=increment,
        signature="*fp32,i32,*fp32",
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    # Set the grid, block dims and shared memory for the
    # kernel (as symbolic expressions)
    launch_params = trtp.KernelLaunchParams()
    num_elements = inp0.shape_expr.numel()
    launch_params.grid_x = trtp.cdiv(num_elements, block_size)
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # Define extra scalar arguments for the
    # kernel (as symbolic expressions)
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(num_elements)

    return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args
```
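
Nothing about this step is TensorRT-specific: the same `triton.compile` call can be run standalone to inspect the mangled kernel name and PTX the plugin will hand back. For example, reusing the `increment` kernel above with a fixed block size:

```py
import triton

src = triton.compiler.ASTSource(
    fn=increment,
    signature="*fp32,i32,*fp32",
    constants={"BLOCK_SIZE": 256},
)
compiled = triton.compile(src)

# The mangled name TensorRT must launch, and the PTX it will JIT-load:
print(compiled.metadata.name)
print(compiled.asm["ptx"].splitlines()[0])
```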


## Using The Plugin

We can use the plugin with {func}`nvtripy.plugin`:

```py
import nvtripy as tp

inp = tp.iota((2, 2))
# Plugin attributes are passed as keyword arguments and must match
# the attributes specified by the registration function.
out = tp.plugin(INCREMENT_PLUGIN_ID, [inp], block_size=256)
assert tp.equal(out, inp + 1) # doc: omit
```
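
The plugin behaves like any other Tripy operation, so it can also be used inside a compiled function. A sketch, assuming `tp.compile` and `tp.InputInfo` from the nvtripy API:

```py
def increment_fn(x):
    return tp.plugin(INCREMENT_PLUGIN_ID, [x], block_size=256)

# Compile for a fixed (2, 2) float32 input; the plugin's PTX is
# compiled ahead of time into the resulting executable.
compiled_increment = tp.compile(
    increment_fn, args=[tp.InputInfo((2, 2), dtype=tp.float32)]
)
out = compiled_increment(inp)
```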