Commit b445ea0

Add a blurb for DeepSeek v3 fine-tuning FP8 Recipe
1 parent f601e42 commit b445ea0

File tree

1 file changed: +22 −1 lines changed


docs/explanations/quantization.md

Lines changed: 22 additions & 1 deletion
@@ -141,4 +141,25 @@ python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME ba
```
Note that `use_qwix_quantization` is not set to `True`.

For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#).

## DeepSeek V3 Fine-tuning FP8 Recipe
To improve the performance of DeepSeek V3 fine-tuning, we developed a custom FP8 fine-tuning recipe optimized for FP8 throughput. By strategically applying 8-bit matrix multiplication and reducing memory overhead in key operations, this approach realizes a 1.27x speedup over the bf16 baseline, with a theoretical headroom of 1.36x. The method prioritizes specific compute-intensive and bandwidth-heavy components while preserving training stability through a fine-grained scaling strategy.
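To illustrate what a fine-grained scaling strategy means in practice, below is a minimal JAX sketch that assigns one FP8 scale per block of rows rather than a single tensor-wide scale. The block size, helper names, and the e4m3 format choice are illustrative assumptions, not the recipe's actual implementation.

```python
# Minimal sketch of fine-grained (per-block) FP8 scaling in JAX.
# Block size and function names are illustrative, not MaxText internals.
import jax
import jax.numpy as jnp

FP8_MAX = jnp.finfo(jnp.float8_e4m3fn).max  # 448 for e4m3

def quantize_blockwise(x, block_rows=128):
    """Quantize to FP8 with one scale per block of rows.

    Per-block scales track local dynamic range, so an outlier only
    costs precision within its own block, not across the whole tensor.
    """
    m, n = x.shape
    assert m % block_rows == 0, "illustrative code: require even blocking"
    blocks = x.reshape(m // block_rows, block_rows, n)
    # One scale per block, chosen so the block's amax maps to FP8_MAX.
    amax = jnp.max(jnp.abs(blocks), axis=(1, 2), keepdims=True)
    scales = jnp.maximum(amax, 1e-12) / FP8_MAX
    q = (blocks / scales).astype(jnp.float8_e4m3fn)
    return q.reshape(m, n), scales  # scales: (m // block_rows, 1, 1)

def dequantize_blockwise(q, scales, block_rows=128):
    m, n = q.shape
    blocks = q.astype(jnp.float32).reshape(m // block_rows, block_rows, n)
    return (blocks * scales).reshape(m, n)

x = jax.random.normal(jax.random.PRNGKey(0), (512, 1024))
q, s = quantize_blockwise(x)
roundtrip_err = jnp.max(jnp.abs(dequantize_blockwise(q, s) - x))
```

In a production recipe, the scale granularity, the format split between activations/weights and gradients, and the fusion of these casts into the surrounding matmuls all affect both accuracy and realized throughput.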
### Quantization Scope

To realize these gains, the recipe employs a w8a8g8 (8-bit weights, activations, and gradients) strategy targeting three primary areas:

* Megablox Kernels: Specifically the `gmm` and `tgmm` operations.
* Attention Projections: Utilizing convolution fusion.
* Communication: Specifically the weight All-Gathers (a toy FP8 all-gather sketch follows this list).
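To make the communication item concrete, here is a toy JAX sketch that gathers weight shards as FP8 payloads plus per-shard scales, instead of gathering bf16 directly. The function name, the single per-shard scale, and the `shard_map` setup are assumptions for illustration; the recipe's real implementation inside MaxText is more involved.

```python
# Toy sketch of an FP8 weight all-gather; names and the single
# per-shard scale are illustrative, not MaxText's implementation.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

FP8_MAX = jnp.finfo(jnp.float8_e4m3fn).max

def fp8_all_gather(w_shard):
    """All-gather weight shards as 1-byte FP8 payloads plus tiny scales,
    roughly halving interconnect bytes versus gathering bf16 directly."""
    scale = jnp.maximum(jnp.max(jnp.abs(w_shard)), 1e-12) / FP8_MAX
    w_q = (w_shard / scale).astype(jnp.float8_e4m3fn)
    w_q_all = jax.lax.all_gather(w_q, 'fsdp')    # (shards, rows, cols), fp8
    scales = jax.lax.all_gather(scale, 'fsdp')   # (shards,), f32
    w_all = w_q_all.astype(jnp.bfloat16) * scales[:, None, None].astype(jnp.bfloat16)
    return w_all.reshape(-1, w_shard.shape[-1])  # dequantized full weight

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('fsdp',))
gather = shard_map(fp8_all_gather, mesh=mesh,
                   in_specs=P('fsdp', None), out_specs=P(None, None),
                   check_rep=False)  # skip the replication check for brevity
w = jnp.ones((8 * len(devices), 16), dtype=jnp.bfloat16)
w_full = gather(w)  # every device ends up with the full bf16 weight
```

Whether the collective actually moves 8-bit data end to end depends on the compiler and runtime, so treat this as a model of the idea rather than a guaranteed bandwidth win.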
### Quantization Headroom
* Attention Projections & Megablox `gmm` kernel (High Operational Intensity): These operations are highly compute-bound. Because their runtime is dominated by the MXU with negligible VPU overhead, they effectively approach the theoretical 2x speedup limit (realizing ~1.9x).
* Megablox `tgmm` kernel (Low Intensity): While the Megablox `tgmm` kernels are also compute-bound, they possess lower operational intensity due to a small reduction dimension (K=1024). This creates a lower roofline ceiling, capping the realized speedup at approximately 1.4x; a toy roofline estimate follows below.
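The roofline argument above can be made concrete with a back-of-the-envelope estimator. The sketch below compares bf16 and FP8 runtime estimates for a single matmul under a simple roofline model; the peak-FLOPs and bandwidth figures are placeholders rather than real TPU specs, so the printed numbers illustrate the shape of the argument, not the measured 1.27x/1.4x figures.

```python
# Back-of-the-envelope roofline model for one matmul. The hardware
# figures are placeholders, not the specs of any particular TPU.

def matmul_time(m, k, n, bytes_in, bytes_out, peak_flops, hbm_bw, reuse=1.0):
    """Roofline estimate: runtime is the max of compute and memory time.
    `reuse` < 1 models imperfect on-chip blocking (inputs re-read
    1/reuse times), which lowers effective operational intensity."""
    flops = 2 * m * k * n
    mem_bytes = (m * k + k * n) * bytes_in / reuse + m * n * bytes_out
    return max(flops / peak_flops, mem_bytes / hbm_bw)

PEAK_BF16 = 200e12  # placeholder FLOP/s
PEAK_FP8 = 400e12   # assumes 2x MXU throughput in FP8
HBM_BW = 1.6e12     # placeholder bytes/s

def fp8_speedup(m, k, n, reuse=1.0):
    t_bf16 = matmul_time(m, k, n, 2, 2, PEAK_BF16, HBM_BW, reuse)
    t_fp8 = matmul_time(m, k, n, 1, 2, PEAK_FP8, HBM_BW, reuse)
    return t_bf16 / t_fp8

# Large K: compute-bound in both precisions -> the full 2.0x.
print(fp8_speedup(8192, 8192, 8192))             # ~2.0
# Small K (1024) with imperfect reuse: memory time dominates and the
# speedup falls well short of 2x (~1.86 under these toy numbers).
print(fp8_speedup(8192, 1024, 8192, reuse=0.02))
```

Plugging in real kernel block sizes, scale-handling overhead, and measured bandwidth is what separates such a toy estimate from the ~1.4x actually observed for `tgmm`.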
### Note on Variance
Please note that these performance gains are derived from the specific configuration of the DeepSeek V3 architecture. Realized FP8 benefits are highly sensitive to model parameters and hardware utilization; consequently, results will vary when this recipe is applied to other models.
