```
Note that `use_qwix_quantization` is not set to `True`.

For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#).
## DeepSeek V3 Fine-tuning FP8 Recipe
To improve the performance of DeepSeek V3 fine-tuning, we developed a custom recipe optimized for FP8 throughput. The method prioritizes specific compute-intensive and bandwidth-heavy components while preserving training stability through a fine-grained scaling strategy.
### Quantization Scope
To realize these gains, the recipe employs a w8a8g8 (8-bit weights, activations, and gradients) strategy targeting three primary areas:
* Megablox kernels: specifically the `gmm` and `tgmm` operations.
* Communication: specifically the weight all-gathers.
### FP8 Recipe
* Rounding: round-to-nearest-even
* Precision:
  * Activations and weights: e4m3fn
  * Gradients: e5m2
* Scaling granularity: per-axis
* Scaling mode:
  * static for weights and activations
  * dynamic for gradients
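To make the dtype and scaling choices above concrete, here is a minimal JAX sketch of per-axis FP8 quantize/dequantize. This is an illustration only, not the MaxText/Qwix implementation; the function names and the example tensor are ours.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0    # largest finite value representable in float8_e4m3fn
E5M2_MAX = 57344.0  # largest finite value representable in float8_e5m2

def quantize_per_axis(x, dtype, max_val, axis):
    # Per-axis scaling: one scale per slice along `axis`, chosen so the
    # largest magnitude in each slice maps to the dtype's max finite value.
    amax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = jnp.where(amax > 0, amax / max_val, 1.0)
    # The cast rounds to nearest even, matching the recipe's rounding mode.
    q = (x / scale).astype(dtype)
    return q, scale

def dequantize(q, scale):
    # Recover an approximation of the original values.
    return q.astype(jnp.float32) * scale

# Activations/weights use e4m3fn per the recipe; gradients would use
# jnp.float8_e5m2 with E5M2_MAX instead.
x = jnp.array([[0.1, -2.0, 3.5], [100.0, -50.0, 0.5]], dtype=jnp.float32)
q, scale = quantize_per_axis(x, jnp.float8_e4m3fn, E4M3_MAX, axis=1)
x_hat = dequantize(q, scale)
```

e4m3fn trades range for mantissa precision (good for weights and activations), while e5m2 keeps a wider exponent range, which is why the recipe reserves it for gradients.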
### Convergence
To validate this recipe, we used MaxText, following the MLPerf Training methodology from MLCommons, to ensure a reproducible and standardized evaluation. Using the C4 dataset (loaded via TFDS) as the reference corpus, we tracked convergence by monitoring validation loss on a held-out split. This aligns with MLPerf’s time-to-quality principle, where the primary metric is the speed at which the model reaches the target quality.
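The time-to-quality idea can be sketched in a few lines: record validation loss at each evaluation step and report the first step at which it crosses the target. The threshold and loss values below are made up for illustration.

```python
def first_step_to_target(loss_history, target):
    """Return the first step whose validation loss reaches `target`,
    or None if the run never converges to the target quality."""
    for step, loss in loss_history:
        if loss <= target:
            return step
    return None

# Hypothetical (step, validation loss) pairs from periodic evaluation.
history = [(100, 2.9), (200, 2.6), (300, 2.45)]
step = first_step_to_target(history, target=2.5)
```

Under this metric, a quantization recipe "wins" only if it reaches the same target loss in less wall-clock time than the BF16 baseline, not merely if it raises raw throughput.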
For this specific case, we derived our training duration from the MLPerf 405B benchmark, targeting roughly 2–3 billion tokens after resuming from a checkpoint. In our configuration, we executed 300 steps with a sequence length of 4096 and a global batch size of 2048, resulting in a total of approximately 2.5 billion tokens.
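The token count stated above follows directly from the run configuration:

```python
# Total training tokens = steps x sequence length x global batch size.
steps = 300
seq_len = 4096
global_batch = 2048
total_tokens = steps * seq_len * global_batch  # 2,516,582,400, i.e. ~2.5B
```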
### Performance Sensitivity
Please note that the FP8 benefits are highly sensitive to model parameters, the efficiency of the BF16 baseline, and hardware utilization; consequently, results will vary when this recipe is applied to other models. Any variance in these factors shifts the ratio of compute-bound to memory-bound operations, directly altering the potential gains.