# Quantization

**Quantization** reduces memory and compute requirements by running operations in low precision:
- **Scaling** is required to translate to/from low precision.
- **Scaling factors** are chosen to minimize accuracy loss.
- They can be either:
  - Loaded into quantization-enabled {class}`nvtripy.Module`s, or
  - Used with {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`.

:::{seealso}
The
[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#working-with-int8)
explains quantization in more detail.
:::
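
For intuition, here is a minimal sketch of the arithmetic behind scaling. This is plain Python with made-up values, not the Tripy API:

```py
scale = 0.05                               # Hypothetical scaling factor
x = 1.37                                   # Original float32 value
q = max(-128, min(127, round(x / scale)))  # Quantize: divide by the scale, round, clamp to the int8 range
x_approx = q * scale                       # Dequantize: multiply by the same scale; ~1.35, a small rounding error
```
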

## Post-Training Quantization With ModelOpt

If the model was not trained with quantization-aware training (QAT), we can use
[TensorRT ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/index.html)
to perform **calibration** and determine scaling factors.

:::{admonition} Info
**Calibration** runs a model on a small set of input data to determine the
numerical distribution of each tensor.

The **dynamic range** is the most important region of this distribution, and
scales are chosen to target it.
:::
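
As a rough illustration (plain Python with hypothetical values, not the calibration API), calibration boils down to observing activation magnitudes and deriving a scale from the largest one:

```py
# Hypothetical activation values observed while running calibration data:
observed = [0.1, -2.5, 1.7, 0.9]
amax = max(abs(v) for v in observed)  # Dynamic range (maximum absolute value): 2.5
scale = amax / 127                    # 127 is the largest value representable in int8
```
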

Let's calibrate a GPT model:

1. Install ModelOpt:

    ```bash
    python3 -m pip install nvidia-modelopt==0.11.1 transformers==4.46.2 datasets==2.21.0
    ```

2. Download the model:

    ```py
    # doc: no-print-locals
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    ```

3. Calibrate for `int8` precision:

    1. Define the forward pass:

        ```py
        # doc: no-output
        from transformers import AutoTokenizer
        from modelopt.torch.utils.dataset_utils import create_forward_loop

        MAX_SEQ_LEN = 512
        tokenizer = AutoTokenizer.from_pretrained(
            "gpt2",
            use_fast=True,
            model_max_length=MAX_SEQ_LEN,
            padding_side="left",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token

        forward_loop = create_forward_loop(
            model=model,
            dataset_name="cnn_dailymail",
            tokenizer=tokenizer,
            device=model.device,
            num_samples=8,
        )
        ```

    2. Set up the quantization configuration (a weight-only variant is shown after these steps):

        ```py
        import modelopt.torch.quantization as mtq

        quant_cfg = mtq.INT8_DEFAULT_CFG
        ```

    3. Run calibration to replace linear layers with
        [`QuantLinear`](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear),
        which contain calibration information:

        ```py
        # doc: no-output
        mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
        ```
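
The default `INT8_DEFAULT_CFG` used above quantizes both inputs and weights. If you only want
weight-only quantization, the input quantizers can be disabled in a copy of the config before
calibration. This variant is not used in the rest of this guide; a sketch, based on the config
structure ModelOpt uses:

```py
import copy

# Weight-only variant: start from the default int8 config and disable input quantizers.
weight_only_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
weight_only_cfg["quant_cfg"]["*input_quantizer"] = {"enable": False}
```
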

The `amax` attributes of `QuantLinear`'s quantizers specify **dynamic ranges**:

```py
torch_qlinear = model.transformer.h[0].attn.c_attn
print(torch_qlinear)
```

We must convert dynamic ranges to scaling factors to load them into Tripy:

```py
def get_scale(quantizer):
    amax = quantizer.export_amax()
    # `maxbound` is the maximum value representable by the data type.
    # For `int8`, this is 127.
    scale = amax.float() / quantizer.maxbound
    return tp.Tensor(scale.squeeze().contiguous())

input_scale = get_scale(torch_qlinear.input_quantizer)
weight_scale = get_scale(torch_qlinear.weight_quantizer)
```

## Loading Scales Into Tripy

### Using Modules

Modules that support quantization usually:
- Expose additional model parameters for scales.
- Accept arguments that control how quantization is performed.

Let's load the scales into an {class}`nvtripy.Linear` module:

```py
qlinear = tp.Linear(
    768,
    2304,
    # The data type to quantize to:
    quant_dtype=tp.int8,
    # The dimension along which the weights are quantized:
    weight_quant_dim=torch_qlinear.weight_quantizer.axis,
)

# Load weights:
qlinear.weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())
qlinear.bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())

# Load scaling factors:
qlinear.input_scale = input_scale
qlinear.weight_scale = weight_scale
```

:::{note}
We use scales from ModelOpt here, but scaling factors can come from anywhere.
:::

We can run it just like a regular `float32` module.
Inputs/weights are quantized internally:

```py
input = tp.ones((1, 768), dtype=tp.float32)

output = qlinear(input)
```

:::{seealso}
`load_quant_weights_from_hf` in the [nanoGPT weight loader](source:/examples/nanogpt/weight_loader.py)
is an example of loading scaling factors for an entire model.
:::
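
As a rough sketch of what such a loader does, the per-layer scales can first be collected from the
calibrated torch model. The module names are assumed (hypothetically) to map onto your Tripy model
definition, so how each entry is applied depends on your model:

```py
# Collect scales for every ModelOpt-quantized linear layer, keyed by module name.
scales = {}
for name, module in model.named_modules():
    if hasattr(module, "weight_quantizer") and hasattr(module, "input_quantizer"):
        scales[name] = {
            "weight_scale": get_scale(module.weight_quantizer),
            "input_scale": get_scale(module.input_quantizer),
        }
```

Each entry can then be assigned to the corresponding quantization-enabled module's
`weight_scale`/`input_scale`, as shown above for a single `Linear`.
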

### Manually

When using {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`,
`dequantize` must **immediately follow** `quantize`.

TensorRT will **rotate** `dequantize` over subsequent ops as needed.

:::{seealso}
The
[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#qdq-placement-recs)
includes recommendations on placement of quantization and dequantization ops.
:::

<!-- We cannot print the quantized input/weight below since that would break Q/DQ rotation -->

To mimic the behavior of the {class}`nvtripy.Linear` module above, we can:

1. Quantize the input:

    ```py
    # doc: no-print-locals
    input = tp.ones((1, 768), dtype=tp.float32)

    input = tp.quantize(input, input_scale, dtype=tp.int8)
    # Note the placement of dequantize:
    input = tp.dequantize(input, input_scale, dtype=tp.float32)
    ```

2. Quantize the weights:

    ```py
    # doc: no-print-locals
    weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())

    dim = torch_qlinear.weight_quantizer.axis
    weight = tp.quantize(weight, weight_scale, dtype=tp.int8, dim=dim)
    weight = tp.dequantize(weight, weight_scale, dtype=tp.float32, dim=dim)
    ```

3. Perform the computation (matrix multiply in this case):

    ```py
    # doc: no-print-locals bias
    bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())

    output = input @ tp.transpose(weight, 0, 1) + bias
    ```


:::{warning}
**Evaluating** the tensor produced by `dequantize` will affect accuracy.

- **Why:** Evaluation replaces the tensor with a constant, losing information
  like which op produced it.

  So, TensorRT won't see `dequantize` when evaluating subsequent ops and
  won't **rotate** it correctly.

For example, **don't** do this:
```py
# doc: no-eval
tensor = tp.ones(...)

tensor = tp.quantize(tensor, ...)
tensor = tp.dequantize(tensor, ...)

# The `print` below will trigger an evaluation of the tensor, which will prevent
# TensorRT from rotating the dequantization node. This will affect accuracy!
print(tensor)

# Rest of the program, including some computation involving tensor
...
```
:::
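
Instead, defer evaluation until after the ops that consume the dequantized tensor. A sketch mirroring
the snippet above, where `final_result` is a placeholder for whatever your program ultimately computes:

```py
# doc: no-eval
tensor = tp.ones(...)

tensor = tp.quantize(tensor, ...)
tensor = tp.dequantize(tensor, ...)

# Rest of the program, including some computation involving tensor
final_result = ...

# Evaluating/printing only downstream results leaves the Q/DQ pair visible to
# TensorRT so it can rotate the dequantization node:
print(final_result)
```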