
Commit 987ba8d

Adds support for source code inspection in guide examples, QDP tests
`triton` requires source code inspection in order to compile the kernel. Since the generated docs dynamically execute code blocks, this is normally not possible. This commit adds a special `ExecNamespace` that will inject source code information as functions are defined. This commit also adds various tests for QDPs and refactors the guide and some of the code.
1 parent 169b1ee commit 987ba8d
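The `ExecNamespace` trick described in the commit message can be sketched as follows. This is a hypothetical, minimal version (the actual implementation is not shown here): the idea is that `inspect.getsource` normally fails for functions defined via `exec()`, but registering the executed source text with `linecache` under a fake filename makes it recoverable. The names `ExecNamespace`, `run_block`, and the filename are illustrative, not taken from the commit.

```python
import inspect
import linecache


class ExecNamespace(dict):
    """Namespace for exec()'d documentation code blocks.

    Hypothetical sketch: makes inspect.getsource() work for functions
    defined inside exec() by having the source registered in linecache.
    """


def run_block(code: str, namespace: dict, filename: str = "<doc-example>") -> None:
    # Register the source under a fake filename so that
    # inspect.findsource() can later retrieve it from linecache.
    # (mtime=None marks the entry as loader-provided, so
    # linecache.checkcache() will not evict it.)
    linecache.cache[filename] = (len(code), None, code.splitlines(True), filename)
    exec(compile(code, filename, "exec"), namespace)


ns = ExecNamespace()
run_block("def increment(x):\n    return x + 1\n", ns)
print(inspect.getsource(ns["increment"]))  # prints the function's source
```

With this in place, tools like `triton.jit` that call `inspect.getsource` on dynamically executed functions can find the code they need.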

File tree

13 files changed (+383, -431 lines)


tripy/docs/README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -160,6 +160,8 @@ Code blocks in docstrings/guides are **preprocessed**:
 
 - `# doc: no-eval` disables execution but this means the code will be **untested**!
 
+- `# doc: ignore-line` disables execution of the indicated line but still includes it in the rendered code.
+
 - Local variables are also displayed. You can customize this:
 
   - **Include** only specific variables: `# doc: print-locals <var1> <var2> ...`
```
tripy/docs/conf.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -157,7 +157,7 @@
 def process_docstring_impl(app, what, name, obj, options, lines):
     doc = "\n".join(lines).strip()
     blocks = helper.consolidate_code_blocks(doc)
-    name = name.lstrip("nvtripy.")
+    name = name.rpartition("nvtripy.")[-1]
 
     # Check signature for functions/methods and class constructors.
     if what in {"function", "method"} or (what == "class" and name in seen_classes):
```
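This one-line fix addresses a classic pitfall: `str.lstrip` interprets its argument as a *set of characters* to strip, not as a prefix, so it can eat into the rest of the name. A quick illustration (the name below is a made-up example, not necessarily one from this codebase):

```python
name = "nvtripy.types.TensorLike"

# lstrip treats "nvtripy." as a character set {n, v, t, r, i, p, y, .},
# so it also strips the leading "typ" of "types":
print(name.lstrip("nvtripy."))          # → "es.TensorLike" (wrong)

# rpartition splits on the last occurrence of the literal substring,
# leaving the rest of the name intact:
print(name.rpartition("nvtripy.")[-1])  # → "types.TensorLike" (correct)
```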
Lines changed: 136 additions & 0 deletions (new file)
# Custom Operations

Plugins allow you to extend TensorRT with custom operations.

- The **quickly deployable plugin** (QDP) framework is the easiest way to write plugins.


## Implementing The Plugin

In this guide, we'll implement a plugin that increments a tensor by 1.

:::{seealso}
[TensorRT's guide on QDPs](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/pluginGuide.html)
includes more details on implementing plugins.
:::

We must:

1. **Register the interface** for the plugin.
2. Implement the **plugin kernel**.
3. Generate [**PTX**](https://docs.nvidia.com/cuda/parallel-thread-execution/).


### Registering The Plugin Interface

[`trtp.register`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/tensorrt.plugin/trt_plugin_register.html#tensorrt.plugin.register)
decorates a function that defines the plugin interface:

```py
import tensorrt.plugin as trtp

# Plugin IDs are of the form: "<namespace>::<name>" and
# uniquely identify a plugin.
INCREMENT_PLUGIN_ID = "example::increment"

@trtp.register(INCREMENT_PLUGIN_ID)
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
    """
    Defines the plugin interface - inputs, outputs and attributes.

    Args:
        inp0: Input tensor descriptor
        block_size: Block size for the Triton kernel

    Returns:
        Output tensor descriptor with same shape/dtype as input
    """
    return inp0.like()
```

### Implementing The Kernel

For this example, we use [OpenAI's Triton language](https://triton-lang.org/main/index.html)
to implement the kernel:

```py
import triton
import triton.language as tl

@triton.jit # doc: ignore-line
def increment(x_ptr, num_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)
```
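To build intuition for what each Triton program instance computes, here is a pure-Python emulation of the kernel above (illustration only; the real kernel runs on the GPU, and the data and block size below are made-up values). Each `pid` handles one block of `BLOCK_SIZE` contiguous elements, and the mask prevents out-of-bounds accesses in the final, partial block:

```python
def increment_block(x, y, pid, num_elements, BLOCK_SIZE):
    # Emulates one program instance of the Triton kernel.
    for i in range(BLOCK_SIZE):
        offset = pid * BLOCK_SIZE + i
        if offset < num_elements:  # the `mask` in the kernel
            y[offset] = x[offset] + 1

x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [0.0] * len(x)
BLOCK = 2
# grid = cdiv(num_elements, BLOCK): enough blocks to cover all elements.
for pid in range((len(x) + BLOCK - 1) // BLOCK):
    increment_block(x, y, pid, len(x), BLOCK)
print(y)  # → [1.0, 2.0, 3.0, 4.0, 5.0]
```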

<!-- Tripy: DOC: OMIT Start -->
<!-- Hack to make source code inspection work - the decorator tries to inspect the source
code before we have injected it, so we need to invoke it *after* the function definition -->
```py
# doc: no-print-locals
increment.__globals__.update({"tl": tl})  # This is required to make `tl` available during `triton.compile`.
increment = triton.jit(increment)
```
<!-- Tripy: DOC: OMIT End -->

:::{note}
Kernels can be written in many other ways, e.g. CUDA, CUTLASS, or Numba, as long as we can emit PTX.
:::


### Retrieving PTX

[`trtp.aot_impl`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/tensorrt.plugin/trt_plugin_aot_impl/index.html#tensorrt.plugin.aot_impl)
decorates a function that retrieves PTX, launch parameters, and any extra arguments:

```py
from typing import Tuple, Union
import tensorrt.plugin as trtp

@trtp.aot_impl(INCREMENT_PLUGIN_ID)
def increment_aot_impl(
    inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    src = triton.compiler.ASTSource(
        fn=increment,
        signature="*fp32,i32,*fp32",
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    # Set the grid, block dims, and shared memory for the
    # kernel (as symbolic expressions):
    launch_params = trtp.KernelLaunchParams()
    num_elements = inp0.shape_expr.numel()
    launch_params.grid_x = trtp.cdiv(num_elements, block_size)
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # Define extra scalar arguments for the
    # kernel (as symbolic expressions):
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(num_elements)

    return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args
```
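The launch grid uses ceiling division (`cdiv`) so that a final, partially filled block is still launched. `trtp.cdiv` builds this as a symbolic expression evaluated at runtime, but numerically it is just the following (plain-Python sketch for illustration):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: number of blocks needed to cover `a` elements
    # with blocks of `b` elements each.
    return (a + b - 1) // b

# A 2x2 input (4 elements) with block_size=256 fits in a single block:
print(cdiv(4, 256))     # → 1
# 1000 elements need 4 blocks of 256 (3 full + 1 partial):
print(cdiv(1000, 256))  # → 4
```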


## Using The Plugin

We can use the plugin with {func}`nvtripy.plugin`:

```py
inp = tp.iota((2, 2))
# Plugin attributes are passed as keyword arguments and must match
# the attributes specified by the registration function.
out = tp.plugin(INCREMENT_PLUGIN_ID, [inp], block_size=256)
assert tp.equal(out, inp + 1) # doc: omit
```

tripy/docs/pre0_user_guides/integrating-tensorrt-python-plugins.md

Lines changed: 0 additions & 116 deletions
This file was deleted.
