
Add pt2e tutorials to torchao doc page #2384


Merged: 1 commit, Jun 17, 2025
15 changes: 14 additions & 1 deletion docs/source/index.rst
@@ -12,6 +12,7 @@ for an overall introduction to the library and recent highlight and updates.
:caption: Getting Started

quick_start
pt2e_quant

.. toctree::
:glob:
@@ -35,11 +36,23 @@ for an overall introduction to the library and recent highlight and updates.
.. toctree::
:glob:
:maxdepth: 1
:caption: Tutorials
:caption: Eager Quantization Tutorials

serialization
subclass_basic
subclass_advanced
static_quantization
pretraining
torchao_vllm_integration

.. toctree::
:glob:
:maxdepth: 1
:caption: PT2E Quantization Tutorials

tutorials_source/pt2e_quant_ptq
tutorials_source/pt2e_quant_qat
tutorials_source/pt2e_quant_x86_inductor
tutorials_source/pt2e_quant_xpu_inductor
tutorials_source/pt2e_quantizer
tutorials_source/openvino_quantizer
Contributor


Is it possible to group these somehow? Maybe a new "PT2E Tutorials" section or something? Right now it looks kind of cluttered and makes other torchao tutorials look small

(Screenshot attached: 2025-06-16 at 4:50:05 PM)

Contributor


Also what do you think about renaming these tutorials to make the titles look more consistent, e.g.

PyTorch 2 Export: Post-Training Quantization (PTQ) (prototype)
PyTorch 2 Export: Quantization-Aware Training (QAT) (prototype)
PyTorch 2 Export: Quantization with X86 Backend through Inductor
PyTorch 2 Export: Quantization with Intel GPU Backend through Inductor
PyTorch 2 Export: How to Write Your Own Quantizer
PyTorch 2 Export: Quantization for OpenVINO torch.compile Backend

Contributor


If these are all under a "PyTorch 2 Export Tutorials" section I feel we can even just drop the prefix in front of each one of these (but keep them on the actual tutorial page if that's possible, not sure how though)

Contributor Author


sure, updated

Contributor Author


I have put these in a separate category

95 changes: 88 additions & 7 deletions docs/source/quick_start.rst
@@ -29,20 +29,20 @@ First, let's set up our toy model:

import copy
import torch

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")

# Optional: compile model for faster inference and generation
model = torch.compile(model, mode="max-autotune", fullgraph=True)
model_bf16 = copy.deepcopy(model)
@@ -99,18 +99,18 @@ it is also much faster!
    benchmark_model,
    unwrap_tensor_subclass,
)

# Temporary workaround for tensor subclass + torch.compile
# Only needed for torch version < 2.5
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(model)

num_runs = 100
torch._dynamo.reset()
example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)
bf16_time = benchmark_model(model_bf16, num_runs, example_inputs)
int4_time = benchmark_model(model, num_runs, example_inputs)

print("bf16 mean time: %0.3f ms" % bf16_time)
print("int4 mean time: %0.3f ms" % int4_time)
print("speedup: %0.1fx" % (bf16_time / int4_time))
@@ -121,6 +121,87 @@ On a single A100 GPU with 80GB memory, this prints::
int4 mean time: 4.410 ms
speedup: 6.9x
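Besides speed, it is often worth sanity-checking that the quantized model still tracks the bf16 baseline numerically. A minimal sketch, assuming torchao's ``compute_error`` helper (signal-to-quantization-noise ratio, in dB) is available from ``torchao.quantization.utils`` and reusing ``model``, ``model_bf16`` and ``example_inputs`` from above::

from torchao.quantization.utils import compute_error  # assumed SQNR helper

with torch.no_grad():
    out_bf16 = model_bf16(*example_inputs)
    out_int4 = model(*example_inputs)

# higher SQNR (in dB) means the int4 outputs stay closer to the bf16 baseline
print("SQNR (dB):", compute_error(out_bf16, out_int4))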

PyTorch 2 Export Quantization
=============================
PyTorch 2 Export Quantization is a full-graph quantization workflow, used mostly for static quantization. It targets hardware that requires both input/output activations and weights to be quantized, and it relies on recognizing operator patterns (such as linear - relu) to make quantization decisions. PT2E quantization produces a graph with quantize and dequantize ops inserted around the operators; during lowering, these quantized operator patterns are fused into real quantized ops. Currently there are two typical lowering paths: (1) torch.compile through Inductor lowering, and (2) ExecuTorch through delegation.

Here we show an example with ``X86InductorQuantizer``.

API Example::

import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)

# initialize a floating point model
float_model = M().eval()

# example inputs used for export (shape matches the Linear(5, 10) above)
example_inputs = (torch.randn(1, 5),)

# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# Step 1. program capture
m = export(float_model, example_inputs).module()
# we get a model with aten ops

# Step 2. quantization
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())

# or prepare_qat_pt2e for Quantization Aware Training
m = prepare_pt2e(m, quantizer)

# run calibration
# calibrate(m, sample_inference_data)
m = convert_pt2e(m)

# Step 3. lowering
# lower to target backend

# Optional: use the C++ wrapper instead of the default Python wrapper
import torch._inductor.config as config
config.cpp_wrapper = True

with torch.no_grad():
    optimized_model = torch.compile(m)

    # run some benchmark
    optimized_model(*example_inputs)


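The comment in the example notes that ``prepare_qat_pt2e`` can be used in place of ``prepare_pt2e`` for Quantization Aware Training. A minimal sketch of that variant, reusing ``float_model``, ``example_inputs`` and ``quantizer`` from above, and assuming ``prepare_qat_pt2e`` is exposed alongside ``prepare_pt2e`` in ``torchao.quantization.pt2e.quantize_pt2e``::

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

# Step 1. program capture, same as the PTQ flow
m = export(float_model, example_inputs).module()

# Step 2. insert fake-quantize ops, then train so observers learn quantization parameters
m = prepare_qat_pt2e(m, quantizer)
# ... run your usual training loop on m here ...

# after training, convert to the quantized representation as in the PTQ flow
m = convert_pt2e(m)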
Please follow these tutorials to get started on PyTorch 2 Export Quantization:

Modeling Users:

- `PyTorch 2 Export Post Training Quantization <https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quant_ptq.html>`_
- `PyTorch 2 Export Quantization Aware Training <https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quant_qat.html>`_
- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor <https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quant_x86_inductor.html>`_
- `PyTorch 2 Export Post Training Quantization with XPU Backend through Inductor <https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quant_xpu_inductor.html>`_
- `PyTorch 2 Export Quantization for OpenVINO torch.compile Backend <https://docs.pytorch.org/ao/stable/tutorials_source/openvino_quantizer.html>`_


Backend Developers (please check out all Modeling Users docs as well):

- `How to Write a Quantizer for PyTorch 2 Export Quantization <https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quantizer.html>`_

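For orientation, a backend quantizer is typically a small class implementing ``annotate`` (attach quantization annotations to the captured aten graph) and ``validate`` (check that the annotated graph is supported). The skeleton below is a rough sketch, not the authoritative API, and assumes the ``Quantizer`` base class is importable from ``torchao.quantization.pt2e.quantizer``::

import torch
from torchao.quantization.pt2e.quantizer import Quantizer  # assumed import path

class MyBackendQuantizer(Quantizer):
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # walk the captured graph and mark the patterns this backend wants
        # quantized (e.g. linear, conv + relu) with quantization specs
        for node in model.graph.nodes:
            ...  # attach annotations for supported patterns here
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # optionally verify the annotated model is something the backend can lower
        pass

See the quantizer tutorial linked above for the actual annotation APIs and a complete example.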

Next Steps
==========