Commit f7fdd1e

expanded esm2 fp8 tests (#1330)
These are currently passing in `gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-base`, but may require some additional changes for the TE that ships in 25.10.

Signed-off-by: Peter St. John <[email protected]>
1 parent 51e1f10 commit f7fdd1e
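
Since the expanded tests gate each recipe on GPU and Transformer Engine support, a quick way to see which recipes would actually run on a given TE build is to call the same `fp8.check_*` helpers the test file uses. The snippet below is a minimal sketch, not part of this commit; it assumes the support-check functions referenced in the diff, and uses a `getattr` guard because `check_fp8_block_scaling_support` and `check_nvfp4_support` may be absent from older TE releases (the 25.10 concern noted above).

# Sketch (not from this commit): report which quantization recipes the expanded
# tests would exercise on the current GPU / Transformer Engine build.
import transformer_engine
from transformer_engine.pytorch import fp8

checks = {
    "DelayedScaling / Float8CurrentScaling": "check_fp8_support",
    "Float8BlockScaling": "check_fp8_block_scaling_support",
    "MXFP8BlockScaling": "check_mxfp8_support",
    "NVFP4BlockScaling": "check_nvfp4_support",
}

print(f"transformer_engine version: {transformer_engine.__version__}")
for recipe_name, check_name in checks.items():
    check = getattr(fp8, check_name, None)  # older TE builds may lack some checks
    if check is None:
        print(f"{recipe_name}: {check_name} not available in this TE build")
        continue
    supported, reason = check()
    status = "supported" if supported else f"skipped ({reason})"
    print(f"{recipe_name}: {status}")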

2 files changed: +93, -75 lines changed

.devcontainer/recipes/Dockerfile

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+# Uncomment to use the latest TE from the NGC registry for debugging changes with latest TE.
+# FROM gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-base
 FROM nvcr.io/nvidia/pytorch:25.10-py3
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \

bionemo-recipes/models/esm2/tests/test_fp8.py

Lines changed: 91 additions & 75 deletions
@@ -18,27 +18,67 @@
 import torch.distributed.checkpoint as dcp
 import transformer_engine
 from torch.distributed.checkpoint.state_dict import get_model_state_dict
-from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling
-from transformer_engine.pytorch.fp8 import check_fp8_support, check_mxfp8_support
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.common import recipe as recipe_module
+from transformer_engine.pytorch import fp8
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
 
 from esm.collator import MLMDataCollatorWithFlattening
 from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
 
 
-def requires_fp8(func):
-    """Decorator to skip tests that require FP8 support."""
-    fp8_available, reason = check_fp8_support()
-    return pytest.mark.skipif(not fp8_available, reason=f"FP8 is not supported on this GPU: {reason}")(func)
-
-
-def requires_mxfp8(func):
-    """Decorator to skip tests that require MXFP8 support."""
-    mxfp8_available, reason = check_mxfp8_support()
-    if torch.cuda.get_device_capability() == (12, 0):
-        mxfp8_available = False
-        reason = "MXFP8 is not supported on sm120"
-    return pytest.mark.skipif(not mxfp8_available, reason=f"MXFP8 is not supported on this GPU: {reason}")(func)
+ALL_RECIPES = [
+    recipe_module.DelayedScaling(),
+    recipe_module.Float8CurrentScaling(),
+    recipe_module.Float8BlockScaling(),
+    recipe_module.MXFP8BlockScaling(),
+    # recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True),
+]
+
+
+def _check_recipe_support(recipe: recipe_module.Recipe):
+    """Check if a recipe is supported and return (supported, reason)."""
+    if isinstance(recipe, recipe_module.DelayedScaling):
+        recipe_supported, reason = fp8.check_fp8_support()
+    elif isinstance(recipe, recipe_module.Float8CurrentScaling):
+        recipe_supported, reason = fp8.check_fp8_support()
+    elif isinstance(recipe, recipe_module.Float8BlockScaling):
+        recipe_supported, reason = fp8.check_fp8_block_scaling_support()
+    elif isinstance(recipe, recipe_module.MXFP8BlockScaling):
+        recipe_supported, reason = fp8.check_mxfp8_support()
+    elif isinstance(recipe, recipe_module.NVFP4BlockScaling):
+        recipe_supported, reason = fp8.check_nvfp4_support()
+    else:
+        recipe_supported = False
+        reason = "Unsupported recipe"
+    return recipe_supported, reason
+
+
+def requires_recipe_support(recipe: recipe_module.Recipe):
+    """Decorator to skip tests that require recipe support."""
+
+    def requires_recipe_support_inner(func):
+        recipe_supported, reason = _check_recipe_support(recipe)
+        return pytest.mark.skipif(not recipe_supported, reason=reason)(func)
+
+    return requires_recipe_support_inner
+
+
+def parametrize_recipes_with_support(recipes):
+    """Generate pytest.param objects with skip marks for unsupported recipes."""
+    parametrized_recipes = []
+    for recipe in recipes:
+        recipe_supported, reason = _check_recipe_support(recipe)
+        parametrized_recipes.append(
+            pytest.param(
+                recipe,
+                id=recipe.__class__.__name__,
+                marks=pytest.mark.skipif(
+                    not recipe_supported,
+                    reason=reason,
+                ),
+            )
+        )
+    return parametrized_recipes
 
 
 @pytest.fixture
@@ -53,24 +93,30 @@ def input_data_thd(tokenizer, tokenized_proteins):
     return data_collator(tokenized_proteins)
 
 
-@requires_fp8
-def test_fp8_forward_and_backward_pass(te_model_checkpoint, input_data):
+@pytest.mark.parametrize("fp8_recipe", parametrize_recipes_with_support(ALL_RECIPES))
+def test_fp8_forward_and_backward_pass(te_model_checkpoint, input_data, fp8_recipe):
     model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
     model_te.to("cuda")
 
     input_data = {k: v.to("cuda") for k, v in input_data.items()}
     outputs = model_te(**input_data)
 
-    fp8_recipe = DelayedScaling()
     with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         outputs_fp8 = model_te(**input_data)
     outputs_fp8.loss.backward()
 
-    torch.testing.assert_close(outputs_fp8.loss, outputs.loss)
+    if isinstance(fp8_recipe, recipe_module.NVFP4BlockScaling):
+        atol = 0.2
+        rtol = 0.05
+    else:
+        atol = None
+        rtol = None
+
+    torch.testing.assert_close(outputs_fp8.loss, outputs.loss, atol=atol, rtol=rtol)
 
 
-@requires_fp8
-def test_fp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd, monkeypatch):
+@pytest.mark.parametrize("fp8_recipe", parametrize_recipes_with_support(ALL_RECIPES))
+def test_fp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd, fp8_recipe, monkeypatch):
     if torch.cuda.get_device_capability() == (12, 0):
         # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default,
         # but it's missing this THD implementation.
@@ -82,54 +128,27 @@ def test_fp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd,
     input_data = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()}
     outputs = model_te(**input_data)
 
-    fp8_recipe = DelayedScaling()
     with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         outputs_fp8 = model_te(**input_data)
     outputs_fp8.loss.backward()
 
-    torch.testing.assert_close(outputs_fp8.loss, outputs.loss)
-
-
-@requires_mxfp8
-def test_mxfp8_forward_and_backward_pass(te_model_checkpoint, input_data):
-    model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
-    model_te.to("cuda")
-
-    input_data = {k: v.to("cuda") for k, v in input_data.items()}
-    outputs = model_te(**input_data)
-
-    mxfp8_recipe = MXFP8BlockScaling()
-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=mxfp8_recipe):
-        outputs_fp8 = model_te(**input_data)
-    outputs_fp8.loss.backward()
-
-    torch.testing.assert_close(outputs_fp8.loss, outputs.loss)
+    if isinstance(fp8_recipe, recipe_module.NVFP4BlockScaling):
+        atol = 0.2
+        rtol = 0.05
+    else:
+        atol = None
+        rtol = None
 
+    torch.testing.assert_close(outputs_fp8.loss, outputs.loss, atol=atol, rtol=rtol)
 
-@requires_mxfp8
-def test_mxfp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd):
-    model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16)
-    model_te.to("cuda")
-
-    input_data = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()}
-    outputs = model_te(**input_data)
 
-    mxfp8_recipe = MXFP8BlockScaling()
-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=mxfp8_recipe):
-        outputs_fp8 = model_te(**input_data)
-    outputs_fp8.loss.backward()
-
-    torch.testing.assert_close(outputs_fp8.loss, outputs.loss)
-
-
-@requires_fp8
-def test_fp8_model_init_forward_and_backward(te_model_checkpoint, input_data):
-    fp8_recipe = DelayedScaling()
+@pytest.mark.parametrize("fp8_recipe", parametrize_recipes_with_support(ALL_RECIPES))
+def test_fp8_model_init_forward_and_backward(te_model_checkpoint, input_data, fp8_recipe):
     config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
     with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe):
         model_te = NVEsmForMaskedLM(config)
 
-    assert isinstance(model_te.lm_head.dense.weight, Float8Tensor)
+    assert isinstance(model_te.lm_head.dense.weight, QuantizedTensor)
 
     model_te.to("cuda")
     input_data = {k: v.to("cuda") for k, v in input_data.items()}
@@ -140,39 +159,34 @@ def test_fp8_model_init_forward_and_backward(te_model_checkpoint, input_data):
     outputs_fp8.loss.backward()
 
 
-@requires_fp8
 @pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init and pretrained loading is not currently supported.")
-def test_fp8_model_init_from_pretrained(te_model_checkpoint, input_data):
-    fp8_recipe = DelayedScaling()
-
+@pytest.mark.parametrize("fp8_recipe", parametrize_recipes_with_support(ALL_RECIPES))
+def test_fp8_model_init_from_pretrained(te_model_checkpoint, fp8_recipe):
     # TODO: this will be renamed to quantized_model_init in the future, fp8_model_init will be removed in 3.0
     with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe):
-        model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint)
+        model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
 
-    assert isinstance(model_te.esm.encoder.layers[0].layernorm_mlp.fc2_weight, Float8Tensor)
-    assert isinstance(model_te.lm_head.dense.weight, Float8Tensor)
+    assert isinstance(model_te.esm.encoder.layers[0].layernorm_mlp.fc2_weight, QuantizedTensor)
+    assert isinstance(model_te.lm_head.dense.weight, QuantizedTensor)
 
 
-@requires_fp8
 @pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init and pretrained saving is not currently supported.")
-def test_fp8_model_init_save_pretrained(te_model_checkpoint, tmp_path):
-    fp8_recipe = DelayedScaling()
+@pytest.mark.parametrize("fp8_recipe", parametrize_recipes_with_support(ALL_RECIPES))
+def test_fp8_model_init_save_pretrained(te_model_checkpoint, tmp_path, fp8_recipe):
     config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
     with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe):
         model_fp8 = NVEsmForMaskedLM(config)
 
-    assert isinstance(model_fp8.esm.encoder.layers[0].layernorm_mlp.fc2_weight, Float8Tensor)
-    assert isinstance(model_fp8.lm_head.dense.weight, Float8Tensor)
+    assert isinstance(model_fp8.esm.encoder.layers[0].layernorm_mlp.fc2_weight, QuantizedTensor)
+    assert isinstance(model_fp8.lm_head.dense.weight, QuantizedTensor)
 
     model_fp8.save_pretrained(tmp_path / "fp8_checkpoint")
     del model_fp8
     NVEsmForMaskedLM.from_pretrained(tmp_path / "fp8_checkpoint", dtype=torch.bfloat16)
 
 
-@requires_fp8
-@pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init and distributed checkpointing is not currently supported.")
-def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint, tmp_path, input_data):
-    fp8_recipe = DelayedScaling()
+@pytest.mark.parametrize("fp8_recipe", parametrize_recipes_with_support(ALL_RECIPES))
+def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint, tmp_path, input_data, fp8_recipe):
     config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
     with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe):
         model_fp8 = NVEsmForMaskedLM(config)
@@ -184,6 +198,7 @@ def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint,
     outputs.loss.backward()
 
     state_dict = get_model_state_dict(model_fp8)
+    state_dict = {key: val for key, val in state_dict.items() if not key.endswith("_extra_state")}
    dcp.save(state_dict, checkpoint_id=tmp_path / "fp8_checkpoint")
 
     del model_fp8, state_dict
@@ -192,4 +207,5 @@ def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint,
     model_fp8 = NVEsmForMaskedLM(config)
 
     state_dict = model_fp8.state_dict()
+    state_dict = {key: val for key, val in state_dict.items() if not key.endswith("_extra_state")}
     dcp.load(state_dict, checkpoint_id=tmp_path / "fp8_checkpoint")
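
Stripped of the ESM-2 specifics, the pattern these tests exercise is: run the same module once in bf16 and once under `fp8_autocast` with a given recipe, then compare results with tolerances that may need loosening for more aggressive formats (as the diff does for NVFP4). The following is an illustrative standalone sketch, not code from this commit; the layer sizes and tolerances are arbitrary assumptions, and shapes are kept at multiples of 16 to satisfy FP8 GEMM alignment requirements.

# Sketch (not from this commit): bf16-vs-quantized comparison on a single TE Linear layer.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe as recipe_module
from transformer_engine.pytorch import fp8

fp8_supported, reason = fp8.check_fp8_support()
assert fp8_supported, f"FP8 not supported on this GPU: {reason}"

layer = te.Linear(768, 768, params_dtype=torch.bfloat16).cuda()
x = torch.randn(32, 768, dtype=torch.bfloat16, device="cuda")

# Reference forward pass in plain bf16.
out_ref = layer(x)

# Same forward pass under an FP8 recipe (DelayedScaling here; any supported recipe works).
with te.fp8_autocast(enabled=True, fp8_recipe=recipe_module.DelayedScaling()):
    out_fp8 = layer(x)
out_fp8.sum().backward()

# Loose, illustrative tolerances: quantized outputs only approximate the bf16 ones.
torch.testing.assert_close(out_fp8, out_ref, atol=0.2, rtol=0.05)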
