
Commit 76ebe7c: Update quantization.md
1 parent b445ea0

1 file changed: 16 additions, 6 deletions

docs/explanations/quantization.md

@@ -145,7 +145,7 @@ For further reading, please refer to the [Qwix Read the Docs website](https://qw
## DeepSeek V3 Fine-tuning FP8 Recipe
- To improve performance of DeepSeek V3 fine-tuning, we developed a custom FP8 fine-tuning recipe optimized for FP8 throughput. By strategically applying 8-bit matrix multiplication and reducing memory overhead in key operations, this approach realized a speedup of 1.27x against the bf16 baseline, with a theoretical headroom of 1.36x. The method prioritizes specific compute-intensive and bandwidth-heavy components while preserving training stability through a fine-grained scaling strategy.
+ To improve the performance of DeepSeek V3 fine-tuning, we developed a custom recipe optimized for FP8 throughput. The method prioritizes specific compute-intensive and bandwidth-heavy components while preserving training stability through a fine-grained scaling strategy.
### Quantization Scope
To realize these gains, the recipe employs a w8a8g8 (8-bit weights, activations and gradients) strategy targeting three primary areas:
@@ -156,10 +156,20 @@ To realize these gains, the recipe employs a w8a8g8 (8-bit weights, activations
* Communication: specifically, the weight all-gathers.
- ### Quantization Headroom
+ ### FP8 Recipe
+ * Rounding: round to nearest even
+ * Precision:
+   * Activations and weights: e4m3fn
+   * Gradients: e5m2
+ * Scaling granularity: per-axis
+ * Scaling mode:
+   * static for weights and activations
+   * dynamic for gradients
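As a rough illustration of the recipe above, here is a minimal, hypothetical sketch of per-axis FP8 quantization in JAX. This is not MaxText or Qwix code: the function names are invented for illustration, and the e4m3fn maximum finite value of 448 is the assumed scaling target. Weights and activations are cast to `float8_e4m3fn` after per-axis scaling; gradients would use `jnp.float8_e5m2` the same way.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quantize_per_axis(x, axis=0, dtype=jnp.float8_e4m3fn, fmax=E4M3_MAX):
    """Scale each slice along `axis` into the FP8 range, then cast.

    Casting uses round-to-nearest-even, matching the recipe above.
    Returns the FP8 tensor and the per-axis scale needed to dequantize.
    """
    amax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = jnp.where(amax > 0, amax / fmax, 1.0)
    q = (x / scale).astype(dtype)  # values now lie within [-fmax, fmax]
    return q, scale

def dequantize(q, scale):
    return q.astype(jnp.float32) * scale

x = jnp.array([[1.0, -2.0], [3.0, 4.0]], dtype=jnp.float32)
q, s = quantize_per_axis(x, axis=0)  # one scale per column
x_hat = dequantize(q, s)             # approximate reconstruction of x
```

The per-axis scale keeps outliers in one row or column from crushing the dynamic range of the rest of the tensor, which is what "fine-grained scaling" buys over a single per-tensor scale.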
- * Attention Projections & Megablox `gmm` kernel (High Operational Intensity): The Megablox `gmm` kernels and Attention Projections are highly compute-bound with high operational intensity. Because their runtime is dominated by the MXU with negligible VPU overhead, they effectively approach the theoretical 2x speedup limit (realizing ~1.9x).
- * Megablox `tgmm` kernel (Low Intensity): While the Megablox `tgmm` kernels are also compute-bound, they possess lower operational intensity due to a small reduction dimension (K=1024). This creates a lower roofline ceiling, capping the realized speedup at approximately 1.4x.
+ ### Convergence
+ To validate this recipe, we used MaxText, following the MLPerf Training methodology from MLCommons, to ensure a reproducible and standardized evaluation. Using the C4 dataset (loaded via TFDS) as the reference corpus, we tracked convergence by monitoring validation loss on a held-out split. This aligns with MLPerf’s time-to-quality principle, where the primary metric is the speed at which the model reaches target quality.
- ### Note on Variance
- Please note that these performance gains are derived from the specific configurations of the DeepSeek V3 architecture. Realized FP8 benefits are highly sensitive to model parameters and hardware utilization; consequently, results will vary when this recipe is applied to other models.
+ For this specific case, we derived our training duration from the MLPerf 405B benchmark, targeting roughly 2–3 billion tokens after resuming from a checkpoint. In our configuration, we ran 300 steps with a sequence length of 4096 and a global batch size of 2048, for a total of approximately 2.5 billion tokens.
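The ~2.5-billion-token figure follows directly from the run configuration; a quick sanity check of the arithmetic:

```python
# Sanity-check the token count implied by the configuration above:
# tokens = steps x sequence length x global batch size.
steps = 300
seq_len = 4096
global_batch_size = 2048

total_tokens = steps * seq_len * global_batch_size
print(f"{total_tokens:,}")  # 2,516,582,400, i.e. ~2.5 billion tokens
```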
+ ### Performance Sensitivity
+ Note that FP8 benefits are highly sensitive to model parameters, the efficiency of the bf16 baseline, and hardware utilization; consequently, results will vary when this recipe is applied to other models. Any variance in these factors shifts the ratio of compute-bound to memory-bound operations, directly altering the potential gains.
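One way to see why the compute-bound/memory-bound split matters is a toy roofline model. All numbers below are illustrative assumptions (normalized units, not measurements of any real chip), and the sketch conservatively assumes FP8 doubles peak matmul throughput while leaving memory traffic unchanged:

```python
# Toy roofline model: attainable throughput is capped either by peak
# compute or by (operational intensity x memory bandwidth).
def attainable_flops(op_intensity, peak_flops, mem_bw):
    return min(peak_flops, op_intensity * mem_bw)

PEAK_BF16 = 1.0      # normalized bf16 peak compute (assumed)
PEAK_FP8 = 2.0       # FP8 peak assumed to be 2x bf16
MEM_BW = 1.0 / 256   # normalized bandwidth: ridge point at OI = 256

for oi in (4096, 64):  # a compute-bound op vs a bandwidth-bound op
    speedup = (attainable_flops(oi, PEAK_FP8, MEM_BW)
               / attainable_flops(oi, PEAK_BF16, MEM_BW))
    print(f"OI={oi:5d} FLOPs/byte -> FP8 speedup {speedup:.1f}x")
```

In this toy model the high-intensity op realizes the full 2x from the doubled compute peak, while the bandwidth-bound op sees none of it; a model with a different mix of such ops therefore lands at a different end-to-end speedup.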
