# Custom GPU Kernels via Triton

PyTorch/XLA now supports [Triton](https://openai.com/research/triton) kernels, enabling high-performance deep learning model execution on GPUs. Triton is a specialized language and compiler for GPU programming that lets developers write custom kernels for the operations in their deep learning models.

Given a Triton kernel defined as follows:
```python3
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out offsets that fall past the end of the vector.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
```

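To see how the grid, `BLOCK_SIZE`, and the mask interact, here is a minimal CPU-only sketch that emulates the kernel's indexing in plain Python. It is not how Triton executes the kernel: `emulate_add_kernel` is a hypothetical helper introduced only for illustration, and a serial loop stands in for the program instances that a GPU would run in parallel.

```python
def emulate_add_kernel(x, y, block_size):
    """Hypothetical CPU stand-in for add_kernel (illustration only)."""
    n_elements = len(x)
    output = [None] * n_elements
    # Ceiling division, the same quantity triton.cdiv computes for the grid.
    num_programs = (n_elements + block_size - 1) // block_size
    for pid in range(num_programs):  # one iteration per program instance
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]
        # The mask disables lanes whose offset runs past the end of the vector.
        mask = [off < n_elements for off in offsets]
        for off, keep in zip(offsets, mask):
            if keep:
                output[off] = x[off] + y[off]
    return output


# 10 elements with a block size of 8 need two program instances;
# the second instance has only two active lanes.
x = list(range(10))
y = list(range(10))
result = emulate_add_kernel(x, y, block_size=8)
print(result)
```

Without the mask, the second program instance would read and write offsets 10 through 15, which lie outside the 10-element vectors.
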
We can make this kernel a part of the PyTorch/XLA execution graph as follows:

```python3
import torch

import torch_xla.experimental.triton as xla_triton
import torch_xla

import triton
import triton.language as tl

size = 16
x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)

# triton_call takes the same arguments as the triton.jit function, in addition
# to the kernel itself and the grid that is used to execute the kernel.
# All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)

# To make the Triton kernel a part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, the payload from triton_call,
# the output shapes, and the output dtypes. The payload already contains
# information regarding how the GPU buffers will be loaded when this node is
# executed.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                              [output.shape], [torch.int64])
```

For more complex kernels, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA.

## Dependencies
The Triton integration depends on the `triton` package. This code is tested with `triton==2.3.0`. To install:
```bash
pip install --no-deps triton==2.3.0
```