Skip to content

Commit d3c0105

Browse files
committed
fixing convert
Signed-off-by: Peter St. John <[email protected]>
1 parent 4cebd39 commit d3c0105

File tree

6 files changed

+24
-22
lines changed

6 files changed

+24
-22
lines changed

bionemo-recipes/models/amplify/src/amplify/state.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -210,12 +210,12 @@ def scale_weights(ctx):
210210
target.to(cast_dtype)
211211
logger.info(f"Casting model to {cast_dtype} complete.")
212212
else:
213-
assert target_orig_dtypes == extract_dtypes(target.named_parameters()), (
214-
f"dtype mismatch between source and target state dicts. "
215-
f"Left side is { {k: v for k, v in target_orig_dtypes.items() if v != torch.bfloat16} }, "
216-
f"Right side is "
217-
f"{ {k: v for k, v in extract_dtypes(target.named_parameters()).items() if v != torch.bfloat16} }"
218-
)
213+
target_new_dtypes = extract_dtypes(target.named_parameters())
214+
for key in target_orig_dtypes.keys():
215+
if key in target_new_dtypes: # For tied weights, these parameters may disappear.
216+
assert target_orig_dtypes[key] == target_new_dtypes[key], (
217+
f"dtype mismatch for key {key}: {target_orig_dtypes[key]} vs {target_new_dtypes[key]}"
218+
)
219219

220220
return target
221221

bionemo-recipes/models/esm2/src/esm/convert.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,11 +71,8 @@ def convert_esm_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
7171
_pad_decoder_weights,
7272
_pad_bias,
7373
],
74-
state_dict_ignored_entries=["lm_head.decoder.weight"],
7574
)
7675

77-
output_model.tie_weights()
78-
7976
return output_model
8077

8178

bionemo-recipes/models/esm2/src/esm/state.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -210,12 +210,12 @@ def scale_weights(ctx):
210210
target.to(cast_dtype)
211211
logger.info(f"Casting model to {cast_dtype} complete.")
212212
else:
213-
assert target_orig_dtypes == extract_dtypes(target.named_parameters()), (
214-
f"dtype mismatch between source and target state dicts. "
215-
f"Left side is { {k: v for k, v in target_orig_dtypes.items() if v != torch.bfloat16} }, "
216-
f"Right side is "
217-
f"{ {k: v for k, v in extract_dtypes(target.named_parameters()).items() if v != torch.bfloat16} }"
218-
)
213+
target_new_dtypes = extract_dtypes(target.named_parameters())
214+
for key in target_orig_dtypes.keys():
215+
if key in target_new_dtypes: # For tied weights, these parameters may disappear.
216+
assert target_orig_dtypes[key] == target_new_dtypes[key], (
217+
f"dtype mismatch for key {key}: {target_orig_dtypes[key]} vs {target_new_dtypes[key]}"
218+
)
219219

220220
return target
221221

bionemo-recipes/models/esm2/tests/test_distributed_fp8.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@ def is_main_process(self) -> bool:
215215
dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes())
216216
recipe_1 = dict_1.pop("recipe")
217217
recipe_2 = dict_2.pop("recipe")
218+
breakpoint()
218219
torch.testing.assert_close(dict_1, dict_2)
219220
assert recipe_1 == recipe_2
220221

bionemo-recipes/models/esm2/tests/test_fp8.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,12 +20,16 @@
2020
from torch.distributed.checkpoint.state_dict import get_model_state_dict
2121
from transformer_engine.common import recipe as recipe_module
2222
from transformer_engine.pytorch import fp8
23-
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
2423

2524
from esm.collator import MLMDataCollatorWithFlattening
2625
from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
2726

2827

28+
try:
29+
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
30+
except ImportError: # TE nightly uses a new import path for QuantizedTensor
31+
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
32+
2933
ALL_RECIPES = [
3034
recipe_module.DelayedScaling(),
3135
recipe_module.Float8CurrentScaling(),

bionemo-recipes/models/llama3/state.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -210,12 +210,12 @@ def scale_weights(ctx):
210210
target.to(cast_dtype)
211211
logger.info(f"Casting model to {cast_dtype} complete.")
212212
else:
213-
assert target_orig_dtypes == extract_dtypes(target.named_parameters()), (
214-
f"dtype mismatch between source and target state dicts. "
215-
f"Left side is { {k: v for k, v in target_orig_dtypes.items() if v != torch.bfloat16} }, "
216-
f"Right side is "
217-
f"{ {k: v for k, v in extract_dtypes(target.named_parameters()).items() if v != torch.bfloat16} }"
218-
)
213+
target_new_dtypes = extract_dtypes(target.named_parameters())
214+
for key in target_orig_dtypes.keys():
215+
if key in target_new_dtypes: # For tied weights, these parameters may disappear.
216+
assert target_orig_dtypes[key] == target_new_dtypes[key], (
217+
f"dtype mismatch for key {key}: {target_orig_dtypes[key]} vs {target_new_dtypes[key]}"
218+
)
219219

220220
return target
221221

0 commit comments

Comments
 (0)