Commit 3079c21
Add docs for AI decoder training with PyTorch (#344)
## Add TensorRT Decoder Training Tutorial

### Overview

This PR introduces comprehensive documentation for training and deploying neural network-based quantum error correction decoders using the TensorRT decoder plugin (`trt_decoder`), which is being released with CUDA-Q QEC v0.5.0.

### Motivation

With the release of the TensorRT decoder, users need clear guidance on:

- How to generate training data for QEC decoding tasks
- How to train custom neural network decoders
- How to export models to ONNX format
- How to deploy trained models with TensorRT for accelerated GPU inference

This tutorial provides an end-to-end workflow demonstrating the complete pipeline from data generation to production deployment.

### Changes Made

#### New Documentation

- **`training_ai_decoders.rst`**: Comprehensive tutorial covering:
  - Training neural network decoders with PyTorch and Stim
  - Surface code circuit generation and data sampling
  - MLP architecture design and training workflow
  - ONNX model export
  - TensorRT deployment with Python and C++ examples
  - Converting ONNX models to TensorRT engines using `trtexec`
  - Performance tuning and best practices

#### Training Script

- **`train_mlp_decoder.py`**: Complete working example demonstrating:
  - Stim-based synthetic data generation for surface codes
  - PyTorch MLP decoder implementation
  - Training loop with validation and early stopping
  - ONNX export for TensorRT deployment
  - Moved from the unittest directory to `docs/sphinx/examples/qec/python/`

#### Documentation Updates

- **`examples.rst`**: Added the new tutorial to the QEC examples table of contents
- **`decoders.rst`**: Minor formatting cleanup

### Tutorial Features

- ✅ Complete end-to-end workflow from training to deployment
- ✅ Both Python and C++ usage examples
- ✅ Multiple precision modes (fp16, fp32, bf16, int8, fp8, tf32, best)
- ✅ Production deployment guidance with pre-built TensorRT engines
- ✅ Best practices and performance tuning tips
- ✅ Clear dependency requirements

### Target Audience

- Users training custom neural network decoders for QEC
- Researchers experimenting with ML-based decoding approaches
- Developers deploying production QEC systems with TensorRT acceleration

### Testing

The training script has been tested and successfully:

- Generates training data using Stim
- Trains an MLP decoder on surface code syndromes
- Exports models to ONNX format compatible with TensorRT
- Achieves convergence with validation accuracy monitoring

---

**Related Components:**

- TensorRT decoder plugin (PR #307)
- CUDA-Q QEC library decoder interface
1 parent ceb2923 commit 3079c21

File tree

7 files changed (+233 −20 lines)

.github/workflows/all_libs.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -109,7 +109,7 @@ jobs:
       # Install the correct torch first.
       cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//')
       pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot}
-      pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09.1
+      pip install numpy pytest onnxscript cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09.1
       # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py.
       if [ "$(uname -m)" == "x86_64" ]; then
         # Stim is not currently available on manylinux ARM wheels, so only
```

.github/workflows/all_libs_release.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -133,7 +133,7 @@ jobs:
       # Install the correct torch first.
       cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//')
       pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot}
-      pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09.1
+      pip install numpy pytest onnxscript cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09.1
       # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py.
       if [ "$(uname -m)" == "x86_64" ]; then
         # Stim is not currently available on manylinux ARM wheels, so only
```

.github/workflows/lib_qec.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -106,7 +106,7 @@ jobs:
       # Install the correct torch first.
       cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//')
       pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot}
-      pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09.1
+      pip install numpy pytest onnxscript cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09.1
       # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py.
       if [ "$(uname -m)" == "x86_64" ]; then
         # Stim is not currently available on manylinux ARM wheels, so only
```

.github/workflows/lib_solvers.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -94,7 +94,7 @@ jobs:
       # Install the correct torch first.
       cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//')
       pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot}
-      pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers pytest
+      pip install numpy pytest onnxscript cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers pytest
 
 
     - name: Run Python tests
```

libs/qec/unittests/decoders/trt_decoder/train_mlp_decoder.py renamed to docs/sphinx/examples/qec/python/train_mlp_decoder.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -1,3 +1,20 @@
+# ============================================================================ #
+# Copyright (c) 2025 NVIDIA Corporation & Affiliates.                         #
+# All rights reserved.                                                        #
+#                                                                             #
+# This source code and the accompanying materials are made available under   #
+# the terms of the Apache License 2.0 which accompanies this distribution.   #
+# ============================================================================ #
+# [Begin Documentation]
+
+import sys
+import platform
+if platform.machine().lower() in ("arm64", "aarch64"):
+    print(
+        "Warning: stim is not supported on manylinux ARM64/aarch64. Skipping this example..."
+    )
+    sys.exit(0)
+
 import stim
 import torch
 import torch.nn as nn
@@ -13,7 +30,7 @@
 num_val_samples = 1000  # Validation samples
 num_test_samples = 1000  # Test samples
 hidden_dim = 128  # Larger model capacity
-error_prob = 0.18  # Balanced error rate for better learning
+error_prob = 0.005  # Physical error rate
 
 # --------------------------
 # Build the surface code circuit
@@ -30,7 +47,7 @@
 # Convert to detector error model
 dem = circuit.detector_error_model()
 num_detectors = dem.num_detectors
-num_data_qubits = circuit.num_qubits - num_detectors  # approx
+num_data_qubits = circuit.num_qubits - num_detectors
 
 print(f"Num data qubits: {num_data_qubits}, Num detectors: {num_detectors}")
 
@@ -69,9 +86,6 @@ def sample_data(num_samples):
 num_observables = Y_train.shape[1]
 print(f"Num observables: {num_observables}")
 
-print(f"X_test: {X_test}")
-print(f"Y_test: {Y_test}")
-
 
 # --------------------------
 # Improved Torch NN decoder with dropout and deeper architecture
@@ -178,8 +192,8 @@ def compute_accuracy(predictions, targets, threshold=0.5):
         val_correct += ((val_output > 0.5).float() == batch_Y).sum().item()
         val_total += batch_Y.numel()
 
-        print(f"logical_error_rate (raw): {batch_Y.sum().item() / batch_Y.numel()}")
-        print(f"cum_ler: {cum_ler / len(val_loader.dataset)}")
+        # print(f"logical_error_rate (raw): {batch_Y.sum().item() / batch_Y.numel()}")
+        # print(f"cum_ler: {cum_ler / len(val_loader.dataset)}")
 
     val_loss_avg = val_loss_total / len(val_loader.dataset)
     val_acc = val_correct / val_total
```
docs/sphinx/examples_rst/qec/decoders.rst

Lines changed: 182 additions & 1 deletion
The 182 added lines (hunk `@@ -134,4 +134,185 @@`) extend `decoders.rst` with a new tutorial section, appended after the existing `See Also` entry for ``cudaq_qec.plugins.decoders.tensor_network_decoder`` (the file's single deleted line is re-added verbatim):

```rst
.. _deploying-ai-decoders:

Deploying AI Decoders with TensorRT
+++++++++++++++++++++++++++++++++++

Starting with CUDA-Q QEC v0.5.0, a GPU-accelerated TensorRT-based decoder is included with the
CUDA-Q QEC library. The TensorRT decoder (``trt_decoder``) enables users to leverage custom AI
models for quantum error correction, providing a flexible framework for deploying trained models
with optimized inference performance on NVIDIA GPUs.

Unlike traditional algorithmic decoders, neural network decoders can be trained on specific error
models and code structures, potentially achieving superior performance for certain noise regimes.
The TensorRT decoder supports loading models in ONNX format and provides configurable precision
modes (fp16, bf16, int8, fp8, tf32) to balance accuracy and inference speed.
```
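The precision mode is chosen at decoder construction time. The C++ example later in this section sets it with `params.insert("precision", "fp16")`; assuming the Python `get_decoder` forwards keyword arguments to the same parameter map (an assumption, not confirmed API), selecting a mode from Python would look roughly like this:

```python
# Sketch (assumption): keyword arguments are forwarded to the decoder's
# parameter map, mirroring params.insert("precision", "fp16") in C++.
decoder = qec.get_decoder("trt_decoder", H,
                          onnx_load_path="ai_decoder.onnx",
                          precision="fp16")  # one of fp16, bf16, int8, fp8, tf32
```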
```rst
This tutorial demonstrates the complete workflow for training a simple multi-layer perceptron (MLP)
to decode surface code syndromes using PyTorch and Stim, exporting the model to ONNX format, and
deploying it with the TensorRT decoder for accelerated inference.

Overview of the Training-to-Deployment Pipeline
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The workflow consists of three main stages:

1. **Data Generation**: Use Stim to generate synthetic quantum error correction data by simulating
   surface code circuits with realistic noise models. This produces detector measurements (syndromes)
   and observable flips (logical errors) that serve as training data; a minimal sketch follows this
   excerpt.

2. **Model Training**: Train a neural network (in this case, an MLP) using PyTorch to learn the
   mapping from syndromes to logical error predictions. The model is trained with standard deep
   learning techniques, including dropout regularization, learning rate scheduling, and validation
   monitoring.

3. **ONNX Export and Deployment**: Export the trained PyTorch model to ONNX format, which can then
   be loaded by the TensorRT decoder for optimized GPU inference in production QEC workflows.
```
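Stage 1 in practice drives Stim's circuit generator and detector sampler. A minimal, self-contained sketch of that data-generation step (the distance, rounds, and error rate below are illustrative, not the tutorial's exact settings):

```python
# Minimal sketch of stage 1: sample (syndrome, logical-flip) training pairs.
# The distance/rounds/error-rate values are illustrative only.
import stim

circuit = stim.Circuit.generated(
    "surface_code:rotated_memory_z",
    distance=3,
    rounds=3,
    after_clifford_depolarization=0.005)  # physical error rate

sampler = circuit.compile_detector_sampler()
detectors, observables = sampler.sample(10_000, separate_observables=True)
# detectors:   (10000, num_detectors) bool array   -> network inputs
# observables: (10000, num_observables) bool array -> training targets
print(detectors.shape, observables.shape)
```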
```rst
Training a Neural Network Decoder with PyTorch and Stim
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The following example shows how to generate training data using Stim's built-in surface code
generator, train an MLP decoder with PyTorch, and export the model to ONNX format.
For instructions on installing PyTorch, see :ref:`Installing PyTorch <installing-pytorch>`.

.. literalinclude:: ../../examples/qec/python/train_mlp_decoder.py
   :language: python
   :start-after: [Begin Documentation]
```
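The export step at the end of the included script is what produces the `ai_decoder.onnx` file consumed below. Since the diff does not show that part of the script, here is a hedged sketch of a typical PyTorch-to-ONNX export; the 24-detector input size and the tensor names are hypothetical, chosen only to match the `trtexec --shapes=detectors:1x24` example later in this section:

```python
# Hedged sketch: export a trained decoder to ONNX for the TensorRT decoder.
# The model shape and tensor names here are hypothetical.
import torch
import torch.nn as nn

model = nn.Sequential(            # stand-in for the trained MLP decoder
    nn.Linear(24, 128), nn.ReLU(),
    nn.Linear(128, 1), nn.Sigmoid())
model.eval()

dummy = torch.zeros(1, 24)        # (batch, num_detectors)
torch.onnx.export(
    model, dummy, "ai_decoder.onnx",
    input_names=["detectors"],
    output_names=["observables"],
    dynamic_axes={"detectors": {0: "batch"},
                  "observables": {0: "batch"}})  # allow variable batch size
```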
```rst
Using the TensorRT Decoder in CUDA-Q QEC
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Once you have a trained ONNX model, you can load it with the TensorRT decoder for accelerated
inference. The decoder can be used in both C++ and Python workflows.

**Loading from ONNX (with automatic TensorRT optimization)**:

.. tab:: Python

   .. code-block:: python

      import cudaq_qec as qec
      import numpy as np

      # Note: The AI decoder doesn't use the parity check matrix.
      # A placeholder matrix is provided here to satisfy the API.
      H = np.array([[1, 0, 0, 1, 0, 1, 1],
                    [0, 1, 0, 1, 1, 0, 1],
                    [0, 0, 1, 0, 1, 1, 1]], dtype=np.uint8)

      # Create TensorRT decoder from ONNX model
      decoder = qec.get_decoder("trt_decoder", H,
                                onnx_load_path="ai_decoder.onnx")

      # Decode a syndrome
      syndrome = np.array([1.0, 0.0, 1.0], dtype=np.float32)
      result = decoder.decode(syndrome)
      print(f"Predicted error: {result}")

.. tab:: C++

   .. code-block:: cpp

      #include "cudaq/qec/decoder.h"
      #include "cuda-qx/core/tensor.h"
      #include "cuda-qx/core/heterogeneous_map.h"

      int main() {
        // Note: The AI decoder doesn't use the parity check matrix.
        // A placeholder matrix is provided here to satisfy the API.
        std::vector<std::vector<uint8_t>> H_vec = {
            {1, 0, 0, 1, 0, 1, 1},
            {0, 1, 0, 1, 1, 0, 1},
            {0, 0, 1, 0, 1, 1, 1}};

        // Convert to tensor
        cudaqx::tensor<uint8_t> H({3, 7});
        for (size_t i = 0; i < 3; ++i) {
          for (size_t j = 0; j < 7; ++j) {
            H.at({i, j}) = H_vec[i][j];
          }
        }

        // Create decoder parameters
        cudaqx::heterogeneous_map params;
        params.insert("onnx_load_path", "ai_decoder.onnx");
        params.insert("precision", "fp16");

        // Create TensorRT decoder
        auto decoder = cudaq::qec::get_decoder("trt_decoder", H, params);

        // Decode syndrome
        std::vector<cudaq::qec::float_t> syndrome = {1.0, 0.0, 1.0};
        auto result = decoder->decode(syndrome);

        return 0;
      }
```
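For throughput-sensitive workloads it is usually better to decode many syndromes per call. Recent CUDA-Q QEC releases expose a batched entry point on the decoder interface; assuming `decode_batch` is available in your version (an assumption worth verifying against your release), usage would look roughly like:

```python
# Sketch, assuming decode_batch exists in your CUDA-Q QEC version.
# Batching amortizes per-call GPU launch overhead.
syndromes = [[1.0, 0.0, 1.0],
             [0.0, 1.0, 1.0]]
results = decoder.decode_batch(syndromes)
for r in results:
    print(r.result)  # soft decision per logical observable
```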
```rst
**Loading a pre-built TensorRT engine (for fastest initialization)**:

If you've already converted your ONNX model to a TensorRT engine (see the next section), you can
load it directly:

.. tab:: Python

   .. code-block:: python

      decoder = qec.get_decoder("trt_decoder", H,
                                engine_load_path="surface_code_decoder.trt")

Converting ONNX Models to TensorRT Engines
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For production deployments where initialization time is critical, you can pre-build a TensorRT
engine from your ONNX model using the ``trtexec`` command-line tool that ships with TensorRT:

.. code-block:: bash

   # Build with FP16 precision
   trtexec --onnx=surface_code_decoder.onnx \
           --saveEngine=surface_code_decoder.trt \
           --fp16

   # Build with the best precision for your GPU
   trtexec --onnx=surface_code_decoder.onnx \
           --saveEngine=surface_code_decoder.trt \
           --best

   # Build with a specific input shape (optional, for optimization)
   trtexec --onnx=surface_code_decoder.onnx \
           --saveEngine=surface_code_decoder.trt \
           --fp16 \
           --shapes=detectors:1x24

Pre-built engines offer several advantages:

- **Faster initialization**: Loading an engine is significantly faster than ONNX parsing and optimization
- **Reproducible optimization**: The same optimization decisions are made every time
- **Version control**: Engines can be versioned alongside code for reproducible deployments
```
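A quick way to sanity-check a built engine before wiring it into a QEC workflow is to let `trtexec` load and benchmark it directly (`--loadEngine` is a standard trtexec flag):

```bash
# Load the pre-built engine and report inference latency/throughput
trtexec --loadEngine=surface_code_decoder.trt
```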
```rst
Dependencies and Requirements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The TensorRT decoder requires:

- **TensorRT**: Version 10.13.3.9 or higher
- **CUDA**: Version 12.0 or higher on x86, version 13.0 on ARM
- **GPU**: NVIDIA GPU with compute capability 6.0+ (Pascal architecture or newer)

For training:

- **PyTorch**: Version 2.0+ recommended
- **Stim**: For quantum circuit simulation and data generation
```
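A minimal training environment can be assembled with pip; the package set below is inferred from this commit's CI changes (which add `onnxscript` alongside `torch`), so treat it as a starting point rather than a pinned specification:

```bash
# Starting point for a training environment (versions not pinned here)
pip install torch stim onnxscript
```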
```rst
See Also
^^^^^^^^

- :class:`cudaq_qec.Decoder` - Base decoder interface
- `ONNX <https://onnx.ai/>`_ - Open Neural Network Exchange format
- `TensorRT Documentation <https://docs.nvidia.com/deeplearning/tensorrt/>`_ - NVIDIA TensorRT
- `Stim Documentation <https://github.com/quantumlib/Stim>`_ - Fast stabilizer circuit simulator
```

docs/sphinx/quickstart/installation.rst

Lines changed: 26 additions & 8 deletions
```diff
@@ -76,14 +76,32 @@ Building from Source
 The instructions for building CUDA-QX from source are maintained on our GitHub
 repository: `Building CUDA-QX from Source <https://github.com/NVIDIA/cudaqx/blob/main/Building.md>`__.
 
-Known Blackwell Issues
-----------------------
-.. note::
-   If you are attempting to use torch on Blackwell, you will need to install the nightly version of torch.
-   You can do this by running:
+.. _installing-pytorch:
+
+Installing PyTorch
+------------------
 
-   .. code-block:: bash
+PyTorch (``torch``) is required for several CUDA-QX features:
 
-      python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
+* **Tensor Network Decoder**: Used by the QEC library for tensor network-based decoding (CPU version of PyTorch is sufficient)
+* **GQE Algorithm**: Used by the Solvers library for the Generative Quantum Eigensolver
+* **Training AI Decoders**: Optionally used for training custom neural network decoders (see :ref:`Deploying AI Decoders with TensorRT <deploying-ai-decoders>`)
 
-torch is a dependency of the tensor network decoder and the GQE algorithm.
+PyTorch is automatically installed when you install the optional components:
+
+.. code-block:: bash
+
+   # Installs PyTorch as a dependency
+   pip install cudaq-qec[tensor-network-decoder]
+   pip install cudaq-solvers[gqe]
+
+Alternatively, you can install PyTorch directly. For detailed installation instructions, visit the
+`PyTorch installation page <https://pytorch.org/get-started/locally/>`_.
+
+.. code-block:: bash
+
+   pip install torch
+
+.. note::
+   Users with NVIDIA Blackwell architecture GPUs require PyTorch with CUDA 12.8 or later support.
+   When installing PyTorch, make sure to select the appropriate CUDA version for your system.
```
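For example, to pull a CUDA 12.8 build explicitly (the `cu128` index follows the same `--index-url` pattern the CI workflows above use):

```bash
# Select a CUDA 12.8 PyTorch wheel explicitly, e.g. for Blackwell GPUs
pip install torch --index-url https://download.pytorch.org/whl/cu128
```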
