
Commit bfbf13f

Merge remote-tracking branch 'github/dev' into ko3n1g/chore/main-to-dev

2 parents: b55a544 + 693587d


48 files changed (+2622, -207 lines)

megatron/core/dist_checkpointing/state_dict_utils.py

Lines changed: 6 additions & 1 deletion
@@ -13,7 +13,7 @@
     StateDict,
     apply_factories,
 )
-from .utils import extract_nonpersistent, extract_sharded_base
+from .utils import _clean_metadata_for_serialization, extract_nonpersistent, extract_sharded_base
 from .validation import determine_global_metadata, validate_sharding_integrity


@@ -43,6 +43,11 @@ def save_preprocess(
     sharded_part = filter_out_empty_flatten_tensor(sharded_part)
     if validate_access_integrity:
         preprocessed_common_state_dict = common_state_dict
+        if "content_metadata" in preprocessed_common_state_dict:
+            preprocessed_common_state_dict["content_metadata"] = _clean_metadata_for_serialization(
+                preprocessed_common_state_dict["content_metadata"]
+            )
+
         if preprocess_common_before_consistancy_check:
             preprocessed_common_state_dict = preprocess_common_before_consistancy_check(
                 common_state_dict
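
This change strips non-serializable objects out of `content_metadata` before the common state dict is used in the access-integrity consistency check. A minimal sketch of the effect, assuming the helper is importable from `megatron.core.dist_checkpointing.utils` as added in this commit; the metadata keys and the `FakeProcessGroup` placeholder are made up for illustration:

```python
from megatron.core.dist_checkpointing.utils import _clean_metadata_for_serialization


class FakeProcessGroup:
    """Stand-in for a torch.distributed ProcessGroup, which cannot be pickled or compared."""


common_state_dict = {
    "content_metadata": {"some_flag": True, "dp_cp_group": FakeProcessGroup()}
}

# Mirrors what save_preprocess now does before the consistency check.
if "content_metadata" in common_state_dict:
    common_state_dict["content_metadata"] = _clean_metadata_for_serialization(
        common_state_dict["content_metadata"]
    )

assert "dp_cp_group" not in common_state_dict["content_metadata"]
assert common_state_dict["content_metadata"]["some_flag"] is True
```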

megatron/core/dist_checkpointing/utils.py

Lines changed: 17 additions & 0 deletions
@@ -330,3 +330,20 @@ def debug_msg(msg: str):
     """
     with logger_stack(None, None) as (stacked_name, last_logger):
         last_logger.debug(f"{stacked_name} {msg}")
+
+
+def _clean_metadata_for_serialization(metadata: dict) -> dict:
+    """Create a clean copy of metadata for serialization by removing non-serializable objects.
+
+    Args:
+        metadata: Original metadata dict
+
+    Returns:
+        Clean metadata dict suitable for serialization
+    """
+    if metadata is None:
+        return None
+    clean_metadata = metadata.copy()
+    # Remove dp_cp_group as it's not serializable
+    clean_metadata.pop('dp_cp_group', None)
+    return clean_metadata
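
The helper is a shallow copy plus a `pop`, so the caller's dict is left intact and `None` passes through unchanged. A small usage sketch (the metadata contents are illustrative only):

```python
from megatron.core.dist_checkpointing.utils import _clean_metadata_for_serialization

metadata = {"some_flag": True, "dp_cp_group": object()}  # object() stands in for a process group
clean = _clean_metadata_for_serialization(metadata)

assert "dp_cp_group" not in clean        # removed from the copy
assert "dp_cp_group" in metadata         # the original dict is untouched
assert _clean_metadata_for_serialization(None) is None
```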

megatron/core/distributed/distributed_data_parallel.py

Lines changed: 29 additions & 21 deletions
@@ -8,7 +8,7 @@

 from .. import parallel_state
 from ..config_logger import has_config_logger_enabled, log_config_to_disk
-from ..fp8_utils import is_float8tensor
+from ..fp8_utils import is_float8tensor, post_all_gather_processing
 from ..process_groups_config import ProcessGroupCollection
 from ..transformer.cuda_graphs import is_graph_capturing
 from ..transformer.transformer_config import TransformerConfig
@@ -500,26 +500,34 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo

         for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
             bucket_group.start_param_sync(force_sync=force_sync)
-            # For MXFP8 params, we need to copy the all-gathered param data from the buffer to
-            # the param.data, since param buffer is not mapped to model params for MXFP8 case.
-            # The paramaters are cast from bf16 to MXFP8 during copy.
-            # In the case of "overlap_param_gather=True", the param copy is done
-            # in "finish_param_sync" stage after zeroing the shared gardient buffers.
-            if (
-                self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag
-                and not self.ddp_config.overlap_param_gather
-            ):
-                for bucket in bucket_group.buckets:
-                    for param in bucket.params:
-                        param_start, param_end = bucket.param_to_index[param]
-                        param_slice = bucket.param_data.view(-1)[param_start:param_end]
-                        param.data.copy_(param_slice.view(param.data.shape))
-                    # All-gathered params are not needed after being copied to param.data.
-                    # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
-                    # We cannot zero out the entire grad buffer because one grad buffer may
-                    # correspond to multiple param buffers. If we zero out the entire grad buffer,
-                    # it would clear the data of those param buffers that have not yet completed AG.
-                    bucket.param_data.zero_()
+
+            if not self.ddp_config.overlap_param_gather:
+                # For MXFP8 params, we need to copy the all-gathered param data from the buffer to
+                # the param.data, since param buffer is not mapped to model params for MXFP8 case.
+                # The paramaters are cast from bf16 to MXFP8 during copy.
+                # In the case of "overlap_param_gather=True", the param copy is done
+                # in "finish_param_sync" stage after zeroing the shared gardient buffers.
+                if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+                    for bucket in bucket_group.buckets:
+                        for param in bucket.params:
+                            param_start, param_end = bucket.param_to_index[param]
+                            param_slice = bucket.param_data.view(-1)[param_start:param_end]
+                            param.data.copy_(param_slice.view(param.data.shape))
+                        # All-gathered params are not needed after being copied to param.data.
+                        # Zero out the param buffer (shared with grad buffer) for gradient
+                        # accumulation. We cannot zero out the entire grad buffer because one grad
+                        # buffer may correspond to multiple param buffers. If we zero out the entire
+                        # grad buffer, it would clear the data of those param buffers that have not
+                        # yet completed AG.
+                        bucket.param_data.zero_()
+                else:
+                    fp8_params = []
+                    for bucket in bucket_group.buckets:
+                        for param in bucket.params:
+                            if is_float8tensor(param):
+                                fp8_params.append(param)
+                    if len(fp8_params) > 0:
+                        post_all_gather_processing(fp8_params)

     def start_grad_sync(self, *unused):
         """

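For the distributed_data_parallel.py change above: the post-all-gather work is now split on `overlap_param_gather`. When the gather is synchronous, MXFP8 params are copied out of the reused grad buffer and the buffer is zeroed, while ordinary FP8 params get their transpose/columnwise storage rebuilt via `post_all_gather_processing`; when the gather is overlapped, the same work is deferred to `finish_param_sync`. A condensed sketch of that control flow (not the actual method; the bucket bookkeeping is simplified):

```python
from megatron.core.fp8_utils import is_float8tensor, post_all_gather_processing


def _post_param_sync_sketch(ddp_config, bucket_group):
    """Simplified view of what start_param_sync now does after launching the all-gather."""
    if ddp_config.overlap_param_gather:
        return  # handled later in finish_param_sync, once the async all-gather completes

    if ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
        # MXFP8: copy gathered data out of the shared param/grad buffer, then zero the
        # buffer so it can be reused for gradient accumulation.
        for bucket in bucket_group.buckets:
            for param in bucket.params:
                start, end = bucket.param_to_index[param]
                param.data.copy_(bucket.param_data.view(-1)[start:end].view(param.data.shape))
            bucket.param_data.zero_()
    else:
        # FP8 (non-MXFP8): rebuild transpose/columnwise data once per bucket group.
        fp8_params = [
            p for bucket in bucket_group.buckets for p in bucket.params if is_float8tensor(p)
        ]
        if fp8_params:
            post_all_gather_processing(fp8_params)
```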
megatron/core/distributed/param_and_grad_buffer.py

Lines changed: 15 additions & 5 deletions
@@ -17,7 +17,12 @@
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.rerun_state_machine import get_rerun_state_machine

-from ..fp8_utils import is_float8tensor, is_mxfp8tensor, modify_underlying_storage
+from ..fp8_utils import (
+    is_float8tensor,
+    is_mxfp8tensor,
+    modify_underlying_storage,
+    post_all_gather_processing,
+)
 from ..utils import is_torch_min_version, log_on_each_pipeline_stage
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .reduce_scatter_with_fp32_accumulation import reduce_scatter_with_fp32_accumulation
@@ -311,10 +316,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
         # For the mxfp8_param with "reuse_grad_buf_for_mxfp8_param_ag=True",
         # we need to copy the param_data from the shared_param/grad_buffer to param.data
         # after the param all-gather.
-        if (
-            self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag
-            and self.ddp_config.overlap_param_gather
-        ):
+        if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
             for bucket in self.buckets:
                 for param in bucket.params:
                     param_start, param_end = bucket.param_to_index[param]
@@ -326,6 +328,14 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
                 # correspond to multiple param buffers. If we zero out the entire grad buffer,
                 # it would clear the data of those param buffers that have not yet completed AG.
                 bucket.param_data.zero_()
+        else:
+            fp8_params = []
+            for bucket in self.buckets:
+                for param in bucket.params:
+                    if is_float8tensor(param):
+                        fp8_params.append(param)
+            if len(fp8_params) > 0:
+                post_all_gather_processing(fp8_params)

     def start_grad_sync(self):
         """

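For the param_and_grad_buffer.py change above: `finish_param_sync` only runs on the overlapped-gather path, so the explicit `overlap_param_gather` check is redundant and is dropped, and the new `else` branch mirrors the DDP change. The collect-and-dispatch pattern now appears in both places; a small helper like the following (not part of the commit, shown only to illustrate the shared pattern) would capture it:

```python
from typing import Iterable

from megatron.core.fp8_utils import is_float8tensor, post_all_gather_processing


def run_fp8_post_all_gather(buckets: Iterable) -> None:
    """Collect FP8 params from the given buckets and rebuild their transpose/columnwise data."""
    fp8_params = [p for bucket in buckets for p in bucket.params if is_float8tensor(p)]
    if fp8_params:
        post_all_gather_processing(fp8_params)
```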
megatron/core/extensions/transformer_engine.py

Lines changed: 48 additions & 6 deletions
@@ -42,6 +42,7 @@
 from megatron.core.transformer.mlp import MLP
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.utils import (
+    ensure_metadata_has_dp_cp_group,
     is_layer_window_attention,
     make_sharded_tensors_for_checkpoint,
 )
@@ -420,6 +421,9 @@ def __init__(
                 # duplicated across TP ranks
                 setattr(param, "sequence_parallel", self.config.sequence_parallel)

+        tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
+
     def forward(self, x):
         """Forward."""
         _is_first_microbatch = (
@@ -444,7 +448,14 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
             self.parallel_mode is None
         ), "TELinear sharded_state_dict can only be used with duplicated parallel mode"
         state_dict = self.state_dict(prefix="", keep_vars=True)
-        return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets)
+        return make_sharded_tensors_for_checkpoint(
+            state_dict,
+            prefix,
+            None,
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
+        )

     def backward_dw(self):
         """Compute weight gradients during the backward pass if delay_wgrad_compute is enabled."""
@@ -492,6 +503,7 @@ def __init__(

         # TODO: For backward compatibility, remove in v0.15.
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group

         # TE returns a zero length Tensor when bias=False and
         # return_bias=True, but we prefer None. So in that case we
@@ -625,9 +637,15 @@ def forward(self, x):

     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharding along axis 0, bias sharded"""
+        metadata = ensure_metadata_has_dp_cp_group(metadata)
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets
+            state_dict,
+            prefix,
+            {"weight": 0, "bias": 0},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )

     def __repr__(self):
@@ -670,6 +688,7 @@ def __init__(
         if gather_output:
             raise ValueError("Transformer Engine linear layers do not support gather_output = True")
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
         world_size = get_pg_size(tp_group)
         rank = get_pg_rank(tp_group)

@@ -720,7 +739,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharding along axis 0, bias sharded"""
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets
+            state_dict,
+            prefix,
+            {"weight": 0, "bias": 0},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )

     def __repr__(self):
@@ -764,6 +788,7 @@ def __init__(
                 "Transformer Engine linear layers do not support input_is_parallel = False"
             )
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group

         super().__init__(
             input_size=input_size,
@@ -814,7 +839,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharding along axis 1, bias not sharded"""
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {"weight": 1}, sharded_offsets
+            state_dict,
+            prefix,
+            {"weight": 1},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )

     def __repr__(self):
@@ -901,6 +931,7 @@ def __init__(
             assert hasattr(
                 pg_collection, "hcp"
             ), "TEDotProductAttention pg_collection must have hierarchical cp pg"
+        self._tp_group = pg_collection.tp

         if is_te_min_version("0.10.0"):
             extra_kwargs["attention_type"] = attention_type
@@ -1078,7 +1109,12 @@ def sharded_state_dict(
         else:
             state_dict = {}
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {'softmax_offset': 0}, sharded_offsets
+            state_dict,
+            prefix,
+            {'softmax_offset': 0},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )


@@ -1138,6 +1174,7 @@ def __init__(
         # The comms between TP and EP group is explicitly handled by MoE token dispatcher.
         # So we disable comms by making TE agnostic of model parallel.
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
         tp_size = get_pg_size(tp_group)

         self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
@@ -1372,7 +1409,12 @@ def _sharded_state_dict_grouped(
                 (ep_axis, global_expert_idx, num_global_experts),
             )
             sub_sd = make_sharded_tensors_for_checkpoint(
-                state_dict, '', tp_axis_map, new_sharded_offsets
+                state_dict,
+                '',
+                tp_axis_map,
+                new_sharded_offsets,
+                tp_group=self._tp_group,
+                dp_cp_group=metadata["dp_cp_group"],
             )
             # Remove expert layers indexing from sharded keys
             replace_prefix_for_sharding(sub_sd, f"{gemm_idx}.", expert_prefix)
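
Each TE wrapper now caches its tensor-parallel group as `self._tp_group` in `__init__` and forwards it, together with `metadata["dp_cp_group"]`, to `make_sharded_tensors_for_checkpoint`, so sharded tensors are built against explicit process groups instead of implicit global parallel state. The calling convention these `sharded_state_dict` methods converge on looks roughly like the sketch below (the `ensure_metadata_has_dp_cp_group` call appears in only one of the methods in this diff and is shown here for completeness; the axis map is just the column-parallel example):

```python
from megatron.core.transformer.utils import (
    ensure_metadata_has_dp_cp_group,
    make_sharded_tensors_for_checkpoint,
)


def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
    """Sketch of the post-change pattern: pass explicit TP and DP-CP groups."""
    metadata = ensure_metadata_has_dp_cp_group(metadata)  # assumed to fill dp_cp_group when missing
    state_dict = self.state_dict(prefix="", keep_vars=True)
    return make_sharded_tensors_for_checkpoint(
        state_dict,
        prefix,
        {"weight": 0, "bias": 0},  # axis map; varies per layer type
        sharded_offsets,
        tp_group=self._tp_group,  # cached in __init__
        dp_cp_group=metadata["dp_cp_group"],
    )
```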

megatron/core/fp8_utils.py

Lines changed: 30 additions & 1 deletion
@@ -85,6 +85,13 @@
     Fp8Padding = None
     Fp8Unpadding = None

+try:
+    from transformer_engine.pytorch.tensor.utils import (
+        post_all_gather_processing as te_post_all_gather_processing,
+    )
+except ImportError:
+    te_post_all_gather_processing = None
+

 def is_float8tensor(tensor: torch.Tensor) -> bool:
     """Check if a tensor is a Transformer Engine Float8Tensor.
@@ -247,7 +254,15 @@ def _quantize_param_shard_impl(
             raise NotImplementedError(
                 f"FSDP with --fp8-param-gather is not supported in TE v{get_te_version()}"
             )
-        cast_master_weights_to_fp8(*args)
+
+        # For newer TE versions (i.e., have post_all_gather_processing function), we keep the
+        # columnwise data and manually call post_all_gather_processing after all-gather, this
+        # makes fp8 params compatible with CUDA graph.
+        kwargs = {}
+        if te_post_all_gather_processing is not None:
+            kwargs["manual_post_all_gather_processing"] = True
+
+        cast_master_weights_to_fp8(*args, **kwargs)

     def _correct_amax_history_if_needed_impl(model: List[torch.nn.Module]) -> None:
         pass
@@ -481,6 +496,20 @@ def correct_amax_history_if_needed(model: List[torch.nn.Module]):
     _correct_amax_history_if_needed_impl(model)


+def post_all_gather_processing(model_params):
+    """
+    Post-processing after all-gather for weights in distributed optimizer.
+    - tensorwise: may need to create a transposed view to match backend GEMM.
+    - blockwise: create column-wise storage.
+    """
+    if te_post_all_gather_processing is not None:
+        te_post_all_gather_processing(model_params)
+    else:
+        # If the TE version is old and does not have post_all_gather_processing function, this is
+        # a no-op, and the transpose/columnwise data will be created in the next forward pass.
+        pass
+
+
 def is_first_last_bf16_layer(config: TransformerConfig, layer_no: int):
     """Check if the layer is in bf16."""
     num_bf16_layers_at_start = (
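
The new `post_all_gather_processing` wrapper is a version guard: with a recent Transformer Engine it delegates to `transformer_engine.pytorch.tensor.utils.post_all_gather_processing`; with an older TE it is a no-op and the transpose/columnwise data is materialized lazily in the next forward pass. The same guard decides whether `cast_master_weights_to_fp8` receives `manual_post_all_gather_processing=True`. A standalone sketch of that optional-feature pattern (the `quantize_kwargs` helper name is hypothetical; the TE import path matches the diff, but availability depends on the installed version):

```python
try:
    from transformer_engine.pytorch.tensor.utils import (
        post_all_gather_processing as te_post_all_gather_processing,
    )
except ImportError:  # older TE without manual post-all-gather support
    te_post_all_gather_processing = None


def quantize_kwargs() -> dict:
    """Extra kwargs for cast_master_weights_to_fp8, only when TE supports manual post-AG."""
    if te_post_all_gather_processing is not None:
        return {"manual_post_all_gather_processing": True}
    return {}
```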

megatron/core/models/bert/bert_model.py

Lines changed: 3 additions & 1 deletion
@@ -14,6 +14,7 @@
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from megatron.core.models.common.language_module.language_module import LanguageModule
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.dot_product_attention import (
     DotProductAttention as MCoreDotProductAttention,
 )
@@ -73,9 +74,10 @@ def __init__(
         seq_len_interpolation_factor: Optional[float] = None,
         add_binary_head=True,
         return_embeddings=False,
+        pg_collection: Optional[ProcessGroupCollection] = None,
         vp_stage: Optional[int] = None,
     ):
-        super(BertModel, self).__init__(config=config)
+        super(BertModel, self).__init__(config=config, pg_collection=pg_collection)

         if has_config_logger_enabled(config):
             log_config_to_disk(config, locals(), prefix=type(self).__name__)
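
`BertModel` now accepts an optional `pg_collection` and forwards it to `LanguageModule.__init__`, matching the process-group plumbing used elsewhere in megatron.core models; when it is left as `None`, behavior presumably falls back to the global parallel state as before. A hedged construction sketch (not runnable as-is: `transformer_config` and `bert_layer_spec` are assumed to exist already, and the other argument values are placeholders, not a tested configuration):

```python
from typing import Optional

from megatron.core.models.bert.bert_model import BertModel
from megatron.core.process_groups_config import ProcessGroupCollection

pg_collection: Optional[ProcessGroupCollection] = None  # or a pre-built collection of process groups

model = BertModel(
    config=transformer_config,                # an existing TransformerConfig
    num_tokentypes=2,                         # illustrative value
    transformer_layer_spec=bert_layer_spec,   # an existing ModuleSpec for the BERT layer
    vocab_size=30522,
    max_sequence_length=512,
    pg_collection=pg_collection,              # new optional argument added in this commit
)
```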
