Commit dc0acb4

Add docs for triton (#7719)
1 parent d39c523 commit dc0acb4

docs/triton.md

Lines changed: 66 additions & 0 deletions

# Custom GPU Kernels via Triton

PyTorch/XLA now supports [Triton](https://openai.com/research/triton) kernels on GPUs. Triton is a language and compiler for GPU programming that lets developers write custom, high-performance kernels for the operations in deep learning models.

Given a Triton kernel defined as follows:
```python3
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes that fall past the end of the vector.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
```

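To see what each program instance of `add_kernel` computes, the indexing and masking logic can be emulated on the CPU. The following is only a sketch in plain Python, not how Triton executes the kernel: `cdiv` and `emulate_add_kernel` are hypothetical helper names, Python lists stand in for GPU buffers, and the programs run in a sequential loop rather than in parallel.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def emulate_add_kernel(x, y, block_size):
    """Hypothetical CPU emulation of add_kernel on plain Python lists."""
    n_elements = len(x)
    output = [None] * n_elements
    # One loop iteration stands in for one program instance (tl.program_id).
    for pid in range(cdiv(n_elements, block_size)):
        block_start = pid * block_size
        # Stand-in for block_start + tl.arange(0, BLOCK_SIZE).
        offsets = [block_start + i for i in range(block_size)]
        # Lanes past the end of the vector are masked off.
        mask = [offset < n_elements for offset in offsets]
        for offset, active in zip(offsets, mask):
            if active:
                output[offset] = x[offset] + y[offset]
    return output


# 10 elements with a block size of 4: the last program has 2 masked lanes.
result = emulate_add_kernel(list(range(10)), list(range(10)), block_size=4)
print(result)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```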
We can make this kernel a part of the PyTorch/XLA execution graph as follows:

```python3
import torch

import torch_xla.experimental.triton as xla_triton
import torch_xla

import triton
import triton.language as tl

size = 16
x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)

# triton_call takes the same arguments as the triton.jit function, in addition
# to the kernel itself and the grid that is used to execute the kernel.
# All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)

# To make the Triton kernel a part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, the payload from triton_call,
# the output shapes and the output dtypes. The payload already contains
# information regarding how the GPU buffers will be loaded when this node is
# executed.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                              [output.shape], [torch.int64])
```

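The `grid` tuple controls how many program instances are launched, and `triton.cdiv` rounds up so that a vector whose length is not a multiple of the block size is still fully covered. A small sketch of that arithmetic in plain Python follows; `launch_shape` is a hypothetical helper used only for illustration and is not part of Triton or PyTorch/XLA.

```python
def launch_shape(size, block_size):
    """Return (number of programs, masked lanes in the last program)."""
    num_programs = (size + block_size - 1) // block_size  # ceiling division
    masked_lanes = num_programs * block_size - size
    return num_programs, masked_lanes


print(launch_shape(16, 8))  # (2, 0): the example above, every lane is active
print(launch_shape(17, 8))  # (3, 7): one extra program, 7 of its lanes masked
```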
For more complex kernels, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA.

## Dependencies

The Triton integration depends on the `triton` package to function. This code is tested with `triton==2.3.0`. To install:

```bash
pip install --no-deps triton==2.3.0
```
