Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/puzzletron/mbridge_distillation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR}

**Start Docker container:**

Use the [NeMo 26.02.01 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02.01):
Use the [NeMo 26.02 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02):

```bash
# Recommended to mount a workspace directory for storing datasets and distilled models
Expand All @@ -31,7 +31,7 @@ docker run --gpus all -it --rm \
-v ${MODELOPT_DIR}:/opt/Model-Optimizer \
-v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \
-w /opt/Model-Optimizer \
nvcr.io/nvidia/nemo:26.02.01 \
nvcr.io/nvidia/nemo:26.02 \
/bin/bash
```

Expand Down Expand Up @@ -66,12 +66,12 @@ Run distillation directly from HuggingFace checkpoints (student and teacher) wit

```bash
torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \
--student_hf_path /path/to/student/huggingface/checkpoint \
--student_hf_path /path/to/student/puzzletron/checkpoint \
--student_hf_model meta-llama/Llama-3.1-8B-Instruct \
--teacher_hf_path /path/to/teacher/huggingface/checkpoint \
--data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \
--output_dir /path/to/distilled/checkpoint \
--hf-export-path /path/to/exported/hf/model \
--hf-model meta-llama/Llama-3.1-8B-Instruct \
--hf_export_path /path/to/exported/hf/model \
--seq_length 4096 \
--tp_size 8 \
--pp_size 1 \
Expand All @@ -90,7 +90,7 @@ torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.

- Add `--trust_remote_code` if student or teacher checkpoints need HuggingFace custom modeling code.
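- The distilled Megatron-Bridge checkpoint will be saved to `<output_dir>/checkpoints/iter_<train_iters>`, where `<output_dir>` is the value passed to `--output_dir` and `<train_iters>` is zero-padded to 7 digits (e.g., `iter_0000100`).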
- Add `--hf-export-path` (or `--hf_export_path`) to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--hf-model` / `--hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation).
- Add `--hf_export_path` to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--student_hf_model`, the HuggingFace model ID used as the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`); this model ID should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation).
- For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices.

## MMLU Evaluation Results
Expand Down
87 changes: 29 additions & 58 deletions examples/puzzletron/mbridge_distillation/distill_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

import argparse
import os
import traceback
import shutil

import megatron.bridge.models.distillation_provider
import torch
from megatron.bridge import AutoBridge
from megatron.bridge.models.distillation_provider import convert_to_distillation_provider
from megatron.bridge.recipes.utils.optimizer_utils import (
distributed_fused_adam_with_cosine_annealing,
)
Expand All @@ -40,39 +40,16 @@
TokenizerConfig,
TrainingConfig,
)
from megatron.bridge.training.distill import distill
from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.distributed import DistributedDataParallelConfig

# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure
# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers
# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge
# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration.
#
# Note: Currently, bridges are also registered when distillation_provider is imported
# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider
# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron.
# Import to register heterogeneous bridges (side effect)
import modelopt.torch.puzzletron.export.mbridge # noqa: F401
import modelopt.torch.utils.distributed as dist

# Use local copy of distillation_provider with fix for heterogeneous models
# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge
from modelopt.torch.puzzletron.export.mbridge.distillation_provider import (
DistillationProvider,
convert_to_distillation_provider,
)
from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import (
export_to_hf_and_copy_config,
)
from modelopt.torch.utils import print_rank_0

# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider
# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time
megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider

# Import distill() AFTER patching so it uses the patched DistillationProvider
from megatron.bridge.training.distill import distill # noqa: E402

SEED = 1234


Expand All @@ -84,13 +61,13 @@ def get_args():
"--student_hf_path",
type=str,
required=True,
help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)",
help="HuggingFace model name or path for the student (standard HF format or puzzletron any_model format)",
)
parser.add_argument(
"--teacher_hf_path",
type=str,
required=True,
help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)",
help="HuggingFace model name or path for the teacher (standard HF format or puzzletron any_model format)",
)
parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code")
# Parallelism arguments
Expand Down Expand Up @@ -145,28 +122,30 @@ def get_args():
# Export arguments
parser.add_argument(
"--hf_export_path",
"--hf-export-path",
type=str,
default=None,
help=(
"Path where to save the HuggingFace export. "
"If provided, exports checkpoint to HF format after distillation."
"If provided, exports last iteration checkpoint to HF format after distillation."
),
)
parser.add_argument(
"--hf_model",
"--hf-model",
"--student_hf_model",
type=str,
required=True,
help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). "
"Should match the base architecture of the student model.",
required=False,
default=None,
help="HuggingFace model ID to use as template for export (e.g., Qwen/Qwen3-0.6B). "
"Should match the base architecture of the student model if --hf_export_path is provided.",
)
args = parser.parse_args()

# Sanity checks
if not args.use_mock_data and not args.data_paths:
raise ValueError("Must provide either --data_paths or set --use_mock_data.")

if args.hf_export_path and not args.student_hf_model:
raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.")

print_rank_0("\n==================== Arguments ====================")
for k, v in args.__dict__.items():
print_rank_0(f"{k:<35} {v}")
Expand Down Expand Up @@ -288,42 +267,34 @@ def _build_model_provider(hf_path):

# Export to HuggingFace format if hf_export_path is provided
if args.hf_export_path:
# Wait for all ranks to finish distillation before export
if torch.distributed.is_initialized():
torch.distributed.barrier()

print_rank_0(f"Exporting final distilled ckpt to HF format to {args.hf_export_path}")
# Save rank before destroying process group (dist.rank() won't work after destruction)
is_rank_0 = dist.rank() == 0

# Destroy process group on all ranks - export_ckpt will create its own temporary one
# This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone)
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
dist.cleanup()

# Only rank 0 exports
if is_rank_0:
try:
export_to_hf_and_copy_config(
student_hf_path=args.student_hf_path,
checkpoint_dir=checkpoint_dir,
train_iters=args.train_iters,
hf_export_path=args.hf_export_path,
hf_model=args.hf_model,
trust_remote_code=args.trust_remote_code,
)
except Exception as e:
print(f"⚠️ Export failed: {e}")
traceback.print_exc()
export_bridge = AutoBridge.from_hf_pretrained(
args.student_hf_model, trust_remote_code=args.trust_remote_code
)
export_bridge.export_ckpt(
megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}",
hf_path=args.hf_export_path,
show_progress=True,
strict=True,
)

# save config from student_model to hf_export_path
shutil.copy(f"{args.student_hf_path}/config.json", f"{args.hf_export_path}/config.json")


# Script entry point: call dist.setup(), parse the command-line arguments, run main(),
# and always call dist.cleanup() afterwards.
# NOTE(review): this text comes from a diff view in which removed and added lines are
# interleaved; the `except` clause below may belong to the removed side - confirm
# against the merged file before relying on it.
if __name__ == "__main__":
dist.setup()
args = get_args()
try:
main(args)
except Exception as e:
# Report the exception type, message, and formatted traceback through print_rank_0,
# then re-raise so the failure still propagates to the caller.
print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}")
print_rank_0(f"Traceback:\n{traceback.format_exc()}")
raise
finally:
# Executed whether main() returned normally or raised.
dist.cleanup()
15 changes: 13 additions & 2 deletions modelopt/torch/puzzletron/export/mbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,20 @@

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig
from megatron.bridge.models.transformer_config import (
HeterogeneousTransformerConfig,
TransformerConfig,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
get_gpt_heterogeneous_layer_spec,
)
from megatron.core.transformer.spec_utils import ModuleSpec

# Monkey-patch: if TransformerConfig has no `get_config_for_layer` method, add a fallback
# that returns the config itself for every layer number.
# (Needed for non-heterogeneous teacher models in this container version.)
if not hasattr(TransformerConfig, "get_config_for_layer"):
TransformerConfig.get_config_for_layer = lambda self, layer_number: self


def heterogeneous_layer_spec(config) -> ModuleSpec:
"""Get GPT heterogeneous layer spec using Transformer Engine."""
Expand Down Expand Up @@ -87,9 +95,12 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider
GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all
the fields that the parent bridge sets.
"""

parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc]

# If no block_configs, fall back to standard (non-heterogeneous) provider.
if not (hasattr(hf_pretrained.config, "block_configs")):
return parent_provider

provider_kwargs = dataclasses.asdict(parent_provider)

# Filter to only fields that GenericHeterogeneousProvider accepts.
Expand Down
Loading
Loading