
Commit 2838b73

Refactors quantization guide
- Refactors quantization guide to be more concise
- Updates `nvidia-modelopt` version so we don't need a special package index
- Reduces threshold for summary mode when pretty printing tensors
- Adds a `# doc: no-output` tag to omit output from code examples
1 parent e0ca435 commit 2838b73

File tree

11 files changed (+226, -134 lines)


tripy/docs/README.md

Lines changed: 4 additions & 2 deletions
@@ -154,9 +154,11 @@ Code blocks in docstrings/guides are **preprocessed**:
 - Code is **executed** and any output is displayed in the docs.

-  If the code throws, doc generation will fail. Use `# doc: allow-exception` to allow exceptions.
+  - `# doc: allow-exception` allows exceptions to be thrown. By default, they are treated as failures.

-  - **Note:** `# doc: no-eval` disables execution but this means the code will be **untested**!
+  - `# doc: no-output` omits output from the docs (but still executes the code).
+
+  - `# doc: no-eval` disables execution but this means the code will be **untested**!

 - Local variables are also displayed. You can customize this:
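As an aside, a hypothetical guide code block using the new tag might look like the following (a sketch only, not part of this commit; it assumes `import nvtripy as tp` is in scope):

```py
# doc: no-output
# This block still runs during doc generation (so it stays tested),
# but its output is omitted from the rendered docs.
a = tp.ones((2, 3), dtype=tp.float32)
print(a)
```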

tripy/docs/pre0_user_guides/00-introduction-to-tripy.md

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ Usage:
     out = fast_mlp(inp)
 ```

-:::{note}
+:::{important}
 There are **restrictions** on what can be compiled - see {func}`nvtripy.compile`.
 :::

Lines changed: 191 additions & 110 deletions
@@ -1,154 +1,235 @@
 # Quantization

+**Quantization** reduces memory and compute requirements by running operations in low precision:
+- **Scaling** is required to translate to/from low precision.
+- **Scaling factors** are chosen such that they minimize accuracy loss.
+- They can be either:
+    - Loaded into quantization-enabled {class}`nvtripy.Module`s, or
+    - Used with {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`.

+:::{seealso}
+The
+[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#working-with-int8)
+explains quantization in more detail.
+:::
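To make the role of a scaling factor concrete, here is a small hand-worked sketch of symmetric `int8` quantization (plain Python with made-up numbers; not part of the guide itself):

```py
amax = 2.54              # dynamic range: largest magnitude we care about
maxbound = 127           # largest magnitude representable in int8
scale = amax / maxbound  # 0.02

x = 1.337
q = max(-128, min(127, round(x / scale)))  # quantize: 67
x_approx = q * scale                       # dequantize: 1.34
# The error (~0.003) stays within half a quantization step because the
# scale was chosen to match the dynamic range of the data.
```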

-## Using Quantized Modules

-Various modules predefined by Tripy support quantization. For example, the {class}`nvtripy.Linear`
-module includes two arguments to configure the quantization mode. Let's construct the following
-quantized linear module:
+## Post-Training Quantization With ModelOpt

-```py
-# doc: print-locals quant_linear
-quant_linear = tp.Linear(
-    4,
-    2,
-    quant_dtype=tp.int8,
-    weight_quant_dim=None,
-)
-```
+If the model was not trained with quantization-aware training (QAT), we can use
+[TensorRT ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/index.html)
+to do **calibration** to determine scaling factors.

-As described in {class}`nvtripy.Linear`, the quantized linear module has
-2 additional parameters compared to a normal linear layer:
+:::{admonition} Info
+**Calibration** runs a model with a small set of input data to determine the
+numerical distribution of each tensor.

-1. `weight_scale`: The quantization scale for `weight`.
+The **dynamic range** is the most important range within this distribution and
+scales are chosen to target this range.
+:::

-2. `input_scale`: The quantization scale for the input.
+Let's calibrate a GPT model:

-`weight_scale` must always be provided while `input_scale` is optional. The input will be quantized
-only if `input_scale` is provided. For a `Linear` module in this example, only "per-tensor" quantization
-is allowed for the input. This is why there is no `input_quant_dim` argument.
+1. Install ModelOpt:

-Let's fill the scale parameters with dummy data:
+```bash
+python3 -m pip install nvidia-modelopt==0.11.1 transformers==4.46.2 datasets==2.21.0
+```

-```py
-# doc: print-locals quant_linear
-quant_linear.weight_scale = tp.Tensor(1.0)
-quant_linear.input_scale = tp.Tensor(1.0)
-```
+2. Download the model:

-and run a forward pass to see the result:
+```py
+# doc: no-print-locals
+from transformers import GPT2LMHeadModel

-```py
-x = tp.iota((3, 4), dtype=tp.float32)
-out = quant_linear(x)
-assert tp.equal(out, tp.Tensor([[0.0000, 1.0000], [6.0000, 23.0000], [12.0000, 45.0000]])) # doc: omit
-```
+model = GPT2LMHeadModel.from_pretrained("gpt2")
+```

-The result still has a data type of {class}`nvtripy.float32`, but internally, TensorRT quantized the
-input and weight, executed the linear layer with {class}`nvtripy.int8` precision, and finally dequantized
-the output back to the original precision.
+3. Calibrate for `int8` precision:

-## Running Quantized Models
+1. Define the forward pass:

-Now that we have covered how quantization works in {class}`nvtripy.Linear`, we will walk through
-the workflow of running a real-world quantized model: [nanoGPT](source:/examples/nanogpt/).
+```py
+# doc: no-output
+from transformers import AutoTokenizer
+from modelopt.torch.utils.dataset_utils import create_forward_loop

-### Calibration With Model Optimizer
+MAX_SEQ_LEN = 512
+tokenizer = AutoTokenizer.from_pretrained(
+    "gpt2",
+    use_fast=True,
+    model_max_length=MAX_SEQ_LEN,
+    padding_side="left",
+    trust_remote_code=True,
+)
+tokenizer.pad_token = tokenizer.eos_token

-<!-- Tripy: TEST: IGNORE Start -->
+forward_loop = create_forward_loop(
+    model=model,
+    dataset_name="cnn_dailymail",
+    tokenizer=tokenizer,
+    device=model.device,
+    num_samples=8,
+)
+```

-The quantization scales are not available unless the model was trained with QAT (quantization-aware training).
-We need to perform another step called calibration to compute the correct scales for each quantized layer.
-There are many ways to do calibration, one of which is using the `nvidia-modelopt` toolkit. To install it, run:
+2. Set up quantization configuration:

-```bash
-python3 -m pip install --extra-index-url https://pypi.nvidia.com nvidia-modelopt==0.11.0 transformers==4.46.2 datasets==2.21.0
-```
+```py
+import modelopt.torch.quantization as mtq

-First, let's get the pre-trained GPT model from hugging face:
+quant_cfg = mtq.INT8_DEFAULT_CFG
+```

-```py
-# doc: no-print-locals
-from transformers import GPT2LMHeadModel
+3. Run calibration to replace linear layers with
+[`QuantLinear`](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear),
+which contain calibration information:
+
+```py
+# doc: no-output
+mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
+```
+
+The `amax` attributes of `QuantLinear`'s quantizers specify **dynamic ranges**:

-model = GPT2LMHeadModel.from_pretrained("gpt2")
+```py
+torch_qlinear = model.transformer.h[0].attn.c_attn
+print(torch_qlinear)
 ```
-Then, we perform int8 weight-only quantization:
+We must convert dynamic ranges to scaling factors to load them into Tripy:

 ```py
-from transformers import AutoTokenizer
-import modelopt.torch.quantization as mtq
-
-from modelopt.torch.utils.dataset_utils import create_forward_loop
-
-# define the modelopt quant configs
-quant_cfg = mtq.INT8_DEFAULT_CFG
-# disable input quantization for weight-only
-# quantized linear modules
-quant_cfg["quant_cfg"]["*input_quantizer"] = {
-    "enable": False,
-}
-
-# define the forward loop for calibration
-MAX_SEQ_LEN = 512
-tokenizer = AutoTokenizer.from_pretrained(
-    "gpt2",
-    use_fast=True,
-    model_max_length=MAX_SEQ_LEN,
-    padding_side="left",
-    trust_remote_code=True,
-)
-tokenizer.pad_token = tokenizer.eos_token
-
-forward_loop = create_forward_loop(
-    model=model,
-    dataset_name="cnn_dailymail",
-    tokenizer=tokenizer,
-    device=model.device,
-    num_samples=8,
-)
-
-# call the api for calibration
-mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
+def get_scale(quantizer):
+    amax = quantizer.export_amax()
+    # `maxbound` is the maximum value representible by the data type.
+    # For `int8`, this is 127.
+    scale = amax.float() / quantizer.maxbound
+    return tp.Tensor(scale.squeeze().contiguous())
+
+input_scale = get_scale(torch_qlinear.input_quantizer)
+weight_scale = get_scale(torch_qlinear.weight_quantizer)
 ```

-`mtq.quantize` replaces all linear layers specified in `quant_cfg` with `QuantLinear`
-layers, which contain the calibrated parameters.

-### Load Scales Into The Tripy Model
+## Loading Scales Into Tripy
+
+### Using Modules
+
+Modules that support quantization usually:
+- Expose additional model parameters for scales.
+- Accept arguments that control how quantization is performed.

-Let's take a look at one of the `QuantLinear` produced by model optimizer:
+Let's load the scales into an {class}`nvtripy.Linear` module:

 ```py
-print(model.transformer.h[0].attn.c_attn)
+qlinear = tp.Linear(
+    768,
+    2304,
+    # The data type to quantize to:
+    quant_dtype=tp.int8,
+    # The dimension along which the weights are quantized:
+    weight_quant_dim=torch_qlinear.weight_quantizer.axis)
+
+# Load weights:
+qlinear.weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())
+qlinear.bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())
+
+# Load scaling factors:
+qlinear.input_scale = input_scale
+qlinear.weight_scale = weight_scale
 ```

-The `amax` attribute gives us the dynamic range of the tensor. Tripy requires scaling factors, so we can convert it like so:
+:::{note}
+We use scales from ModelOpt here, but scaling factors can come from anywhere.
+:::
+
+We can run it just like a regular `float32` module.
+Inputs/weights are quantized internally:

 ```py
-def convert_to_scale(amax, maxbound):
-    return amax.float() / maxbound
+input = tp.ones((1, 768), dtype=tp.float32)
+
+output = qlinear(input)
 ```
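Since the note above says the scaling factors can come from anywhere, one hypothetical alternative (not in this commit) is to derive a crude per-tensor weight scale straight from the weights with absolute-max calibration; this would pair with `weight_quant_dim=None` rather than the per-channel axis used above:

```py
# Sketch: per-tensor absolute-max calibration computed directly from the weights.
# 127 is the maximum magnitude representable in int8.
w = torch_qlinear.weight.detach()
per_tensor_weight_scale = tp.Tensor((w.abs().max() / 127.0).contiguous())
```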
-Let's convert the `amax` to the scaling factor and load it to a compatible {class}`nvtripy.Linear` module:
+:::{seealso}
+`load_quant_weights_from_hf` in the [nanoGPT weight loader](source:/examples/nanogpt/weight_loader.py)
+is an example of loading scaling factors for an entire model.
+:::
+
+### Manually

+When using {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`,
+`dequantize` must **immediately follow** `quantize`.
+
+TensorRT will **rotate** `dequantize` over subsequent ops as needed.
+
+:::{seealso}
+The
+[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#qdq-placement-recs)
+includes recommendations on placement of quantization and dequantization ops.
+:::
+
+<!-- We cannot print the quantized input/weight below since that would break Q/DQ rotation -->
+
+To mimic the behavior of the {class}`nvtripy.Linear` module above, we can:
+
+1. Quantize the input:
+
+```py
+# doc: no-print-locals
+input = tp.ones((1, 768), dtype=tp.float32)
+
+input = tp.quantize(input, input_scale, dtype=tp.int8)
+# Note the placement of dequantize:
+input = tp.dequantize(input, input_scale, dtype=tp.float32)
+```
+
+2. Quantize the weights:
+
+```py
+# doc: no-print-locals
+weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())
+
+dim = torch_qlinear.weight_quantizer.axis
+weight = tp.quantize(weight, weight_scale, dtype=tp.int8, dim=dim)
+weight = tp.dequantize(weight, weight_scale, dtype=tp.float32, dim=dim)
+```
+
+3. Perform the computation (matrix multiply in this case):
+
+```py
+# doc: no-print-locals bias
+bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())
+
+output = input @ tp.transpose(weight, 0, 1) + bias
+```
+
+:::{warning}
+**Evaluating** the tensor produced by `dequantize` will affect accuracy.
+
+- **Why:** Evaluation replaces the tensor with a constant, losing information
+  like which op produced it.
+
+  So, TensorRT won't see `dequantize` when evaluating subsequent ops and
+  won't **rotate** it correctly.
+
+For example, **don't** do this:
 ```py
-# doc: print-locals weight_only_qlinear
-weight_only_qlinear = tp.Linear(
-    768,
-    2304,
-    quant_dtype=tp.int8,
-    weight_quant_dim=0,
-)
-quantizer = model.transformer.h[0].attn.c_attn.weight_quantizer
-scale = convert_to_scale(quantizer.export_amax(), quantizer.maxbound)
-scale = scale.squeeze().contiguous()
-weight_only_qlinear.weight_scale = tp.Tensor(scale)
-```
+# doc: no-eval
+tensor = tp.ones(...)

-For an example of how to load weights from a quantized model, refer to
-[load_quant_weights_from_hf](source:/examples/nanogpt/weight_loader.py) from the nanoGPT example.
+tensor = tp.quantize(tensor, ...)
+tensor = tp.dequantize(tensor, ...)

-<!-- Tripy: TEST: IGNORE End -->
+# The `print` below will trigger an evaluation of the tensor which will prevent
+# TensorRT from rotating the dequantization node. This will affect accuracy!
+print(tensor)
+
+# Rest of the program, including some computation involving tensor
+...
+```
+:::
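By contrast, a sketch of the pattern the warning is steering toward: keep the quantize/dequantize pair and the computation that consumes it in one unevaluated stretch (reusing `input_scale`, `weight`, and `bias` from the manual steps above; illustrative only):

```py
inp = tp.ones((1, 768), dtype=tp.float32)

inp = tp.quantize(inp, input_scale, dtype=tp.int8)
inp = tp.dequantize(inp, input_scale, dtype=tp.float32)

# No print()/evaluation in between, so TensorRT still sees the dequantize
# when building the ops below and can rotate it over them.
out = inp @ tp.transpose(weight, 0, 1) + bias
```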

tripy/docs/pre0_user_guides/02-compiler.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 Modules and functions can be compiled for better performance.

-:::{note}
+:::{important}
 There are **restrictions** on what can be compiled - see {func}`nvtripy.compile`.
 :::
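For orientation, a minimal sketch of what compiling a function might look like (the exact `tp.compile`/`tp.InputInfo` usage is an assumption here; check {func}`nvtripy.compile` for the authoritative API, and shapes are made up):

```py
import nvtripy as tp

def add(a, b):
    return a + b

# Compile for fixed-shape float32 inputs (assumed API usage):
fast_add = tp.compile(
    add,
    args=[
        tp.InputInfo((2, 2), dtype=tp.float32),
        tp.InputInfo((2, 2), dtype=tp.float32),
    ],
)

out = fast_add(tp.ones((2, 2)), tp.ones((2, 2)))
```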

tripy/examples/nanogpt/requirements.txt

Lines changed: 1 addition & 2 deletions
@@ -4,6 +4,5 @@ transformers==4.46.2
 tiktoken==0.5.2
 --extra-index-url https://download.pytorch.org/whl/cu121
 torch==2.3.0
---extra-index-url https://pypi.nvidia.com
-nvidia-modelopt==0.11.0
+nvidia-modelopt==0.11.1
 datasets==2.18.0
