diff --git a/megatron/core/pipeline_parallel/bridge_communicator.py b/megatron/core/pipeline_parallel/bridge_communicator.py index f1e74a2f16..9580f3924c 100644 --- a/megatron/core/pipeline_parallel/bridge_communicator.py +++ b/megatron/core/pipeline_parallel/bridge_communicator.py @@ -494,7 +494,7 @@ def recv_backward(self) -> torch.Tensor: received_gradients_list.append(grad_tensor) # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=0) + aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b']) logging.debug( f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -615,7 +615,7 @@ def send_forward_recv_backward( req.wait() # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=0) + aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b']) logging.debug( f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -737,7 +737,9 @@ def send_backward_recv_forward( req.wait() # Concatenate received activations - aggregated_activation = torch.cat(received_activations_list, dim=0) + aggregated_activation = torch.cat( + received_activations_list, dim=self.dim_mapping['b'] + ) logging.debug( f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}" diff --git a/megatron/core/pipeline_parallel/multimodule_communicator.py b/megatron/core/pipeline_parallel/multimodule_communicator.py index dfda270ef7..4733db6511 100644 --- a/megatron/core/pipeline_parallel/multimodule_communicator.py +++ b/megatron/core/pipeline_parallel/multimodule_communicator.py @@ -16,6 +16,32 @@ Shape = Union[List[int], torch.Size] +def _ensure_3d_tensor(tensor): + """Ensure tensor is 3D for P2P/bridge communication. + + P2P and bridge communicators expect 3D tensors. + Handles both single tensors and lists of tensors (for VPP). + """ + if isinstance(tensor, list): + return [_ensure_3d_tensor(t) for t in tensor] + if isinstance(tensor, torch.Tensor) and tensor.ndim == 2: + return tensor.unsqueeze(-1) + return tensor + + +def _restore_tensor_shape(tensor): + """Restore original tensor shape after P2P/bridge communication. + + Remove the extra dimension added by _ensure_3d_tensor if it was singleton. + Handles both single tensors and lists of tensors (for VPP). + """ + if isinstance(tensor, list): + return [_restore_tensor_shape(t) for t in tensor] + if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1: + return tensor.squeeze(-1) + return tensor + + @dataclass class RankModuleInfo: """Information about a rank in a module.""" @@ -281,12 +307,12 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool # If last stage, and has outgoing modules, send forward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - bridge_comm.send_forward(output_dict[module_name]) + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + bridge_comm.send_forward(tensor_to_send) else: # If not last stage, send forward activation by using P2P communicator. 
- rank_module_info.p2p_communicator.send_forward( - output_dict[module_name], is_last_stage=False - ) + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + rank_module_info.p2p_communicator.send_forward(tensor_to_send, is_last_stage=False) def send_forward_recv_backward( self, @@ -303,28 +329,23 @@ def send_forward_recv_backward( Returns: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_forward_recv_backward] output_dict keys: {output_dict.keys()}, " - f"tensor_shape: {tensor_shape}, is_last_stage: {is_last_stage}" - ) grad_dict = {} for module_name, rank_module_info in self.rank_module_map.items(): if rank_module_info.pp_rank == rank_module_info.pp_size - 1: # If last stage, and has outgoing modules, send forward activation and # receive backward gradient by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - grad_dict[bridge_comm.src_module_name] = bridge_comm.send_forward_recv_backward( - output_dict[module_name] - ) + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + grad_tensor = bridge_comm.send_forward_recv_backward(tensor_to_send) + grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad_tensor) else: # If not last stage, send forward activation and receive backward gradient # by using P2P communicator. - grad_dict[module_name] = ( - rank_module_info.p2p_communicator.send_forward_recv_backward( - output_dict[module_name], tensor_shapes=tensor_shape, is_last_stage=False - ) + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + grad_tensor = rank_module_info.p2p_communicator.send_forward_recv_backward( + tensor_to_send, tensor_shapes=tensor_shape, is_last_stage=False ) + grad_dict[module_name] = _restore_tensor_shape(grad_tensor) return grad_dict def send_backward_recv_forward( @@ -353,19 +374,19 @@ def send_backward_recv_forward( for bridge_comm in rank_module_info.bridge_comms_as_dest_module: # If first stage, and has incoming modules, send backward gradient and # receive forward activation by using bridge communicator. - input_dict[bridge_comm.src_module_name] = ( - bridge_comm.send_backward_recv_forward( - grad_dict[bridge_comm.src_module_name] - ) + grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + activation_tensor = bridge_comm.send_backward_recv_forward(grad_to_send) + input_dict[bridge_comm.src_module_name] = _restore_tensor_shape( + activation_tensor ) else: # If not first stage, send backward gradient and receive forward activation # by using P2P communicator. - input_dict[module_name] = ( - rank_module_info.p2p_communicator.send_backward_recv_forward( - grad_dict[module_name], tensor_shapes=tensor_shape, is_first_stage=False - ) + grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) + activation_tensor = rank_module_info.p2p_communicator.send_backward_recv_forward( + grad_to_send, tensor_shapes=tensor_shape, is_first_stage=False ) + input_dict[module_name] = _restore_tensor_shape(activation_tensor) return input_dict def recv_backward( @@ -389,12 +410,14 @@ def recv_backward( # If last stage, and has incoming modules, receive backward gradient # by using bridge communicator. 
for bridge_comm in rank_module_info.bridge_comms_as_src_module: - grad_dict[bridge_comm.src_module_name] = bridge_comm.recv_backward() + recv_grad_tensor = bridge_comm.recv_backward() + grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(recv_grad_tensor) else: # If not last stage, receive backward gradient by using P2P communicator. - grad_dict[module_name] = rank_module_info.p2p_communicator.recv_backward( + recv_grad_tensor = rank_module_info.p2p_communicator.recv_backward( tensor_shapes=tensor_shape, is_last_stage=False ) + grad_dict[module_name] = _restore_tensor_shape(recv_grad_tensor) return grad_dict def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool = False): @@ -412,12 +435,12 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool # If first stage, and has incoming modules, send backward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - bridge_comm.send_backward(grad_dict[bridge_comm.src_module_name]) + grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + bridge_comm.send_backward(grad_to_send) else: # If not first stage, send backward activation by using P2P communicator. - rank_module_info.p2p_communicator.send_backward( - grad_dict[module_name], is_first_stage=False - ) + grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) + rank_module_info.p2p_communicator.send_backward(grad_to_send, is_first_stage=False) @staticmethod def compute_total_pipeline_stages( diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index ac839c21f1..86efebc4f0 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -7,6 +7,7 @@ import torch.distributed as dist from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage from megatron.core.utils import nvtx_decorator # Types @@ -162,6 +163,21 @@ def __init__(self, pp_group: dist.ProcessGroup, config: ModelParallelConfig): else None ) + @property + def is_pp_first_stage(self): + """Return True if pp first stage.""" + return is_pp_first_stage(self.pp_group) + + @property + def is_pp_last_stage(self): + """Return True if pp last stage.""" + return is_pp_last_stage(self.pp_group) + + @property + def num_warmup_microbatches(self): + """Return number of warmup microbatches.""" + return self.pp_group.size() - self.pp_group.rank() - 1 + def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, recv_next): """Communicate tensor shapes between stages. Used to communicate tensor shapes before the actual tensor communication happens. 
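Note (reviewer annotation, not part of the patch): the new multimodule_communicator helpers and the P2PCommunicator properties above encode two small invariants worth stating explicitly: a 2-D activation/gradient gets a trailing singleton dimension before crossing a bridge or P2P link and loses it again on the other side, and each rank's 1F1B warmup count is pp_size - pp_rank - 1. A minimal, torch-only sketch with local stand-in functions (not the patch's own helpers):

    import torch

    def ensure_3d(t):
        # Bridge/P2P paths expect an (s, b, h)-shaped tensor; pad a trailing
        # singleton dim when a 2-D tensor comes in.
        return t.unsqueeze(-1) if isinstance(t, torch.Tensor) and t.ndim == 2 else t

    def restore(t):
        # Undo the padding only when the trailing dim is a singleton
        # (mirrors the patch's "if it was singleton" assumption).
        if isinstance(t, torch.Tensor) and t.ndim == 3 and t.shape[-1] == 1:
            return t.squeeze(-1)
        return t

    x = torch.randn(128, 4)                        # (seq, batch) without a hidden dim
    assert restore(ensure_3d(x)).shape == x.shape

    # Warmup microbatches for non-interleaved 1F1B, as exposed by the new
    # num_warmup_microbatches property: later pipeline stages warm up less.
    pp_size = 4
    print([pp_size - rank - 1 for rank in range(pp_size)])   # [3, 2, 1, 0]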
diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 11e54e0fa5..fe7f42ae1a 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -2,7 +2,7 @@ import contextlib from functools import partial -from typing import Callable, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Union import torch from torch.autograd.variable import Variable @@ -12,6 +12,7 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( fine_grained_offloading_reset, ) +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, @@ -135,7 +136,10 @@ def forward_step(data_iterator, model): return forward_backward_func -def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): +def deallocate_output_tensor( + out: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]], + deallocate_pipeline_outputs=False, +): '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. This method should be called right after the output tensor has been @@ -144,9 +148,19 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): ''' if (out is None) or (not deallocate_pipeline_outputs): return - assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ - assert out._base is None, "counter-productive to free a view of another tensor." - out.data = torch.empty((1,), device=out.device, dtype=out.dtype) + if isinstance(out, torch.Tensor): + assert out._base is None, "counter-productive to free a view of another tensor." + out.data = torch.empty((1,), device=out.device, dtype=out.dtype) + + if isinstance(out, dict): + for v in out.values(): + deallocate_output_tensor(v, deallocate_pipeline_outputs) + return + + if isinstance(out, (list, tuple)): + for v in out: + deallocate_output_tensor(v, deallocate_pipeline_outputs) + return def custom_backward(output, grad_output): @@ -425,7 +439,7 @@ def forward_step( return [output_tensor], num_tokens -def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): +def backward_step_tensor(input_tensor, output_tensor, output_tensor_grad, model_type, config): """Backward step through passed-in output tensor. 
     If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -489,6 +503,101 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c
     return input_tensor_grad


+def _backward_step_dict(
+    input_tensor: Dict[str, torch.Tensor],
+    output_tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    output_tensor_grad: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    model_type: str,
+    config: Any,
+):
+    """Backward step implementation when inputs/outputs are dictionaries (multi-module case)."""
+
+    if config.timers is not None:
+        config.timers('backward-compute', log_level=2).start()
+
+    # Retain gradients on all input tensors
+    for module_name, tensor in input_tensor.items():
+        if isinstance(tensor, list):
+            tensor = tensor[0]
+        if tensor is not None:
+            tensor.retain_grad()
+
+    # Last stage: output_tensor is a scalar loss; wrap it in a dict for uniform handling,
+    # using the sole input-tensor key as the module name. For now the last stage hosts
+    # a single module (the LLM); these asserts should eventually move to the top-level
+    # train loop.
+    if not isinstance(output_tensor, dict):
+        all_keys = list(input_tensor.keys())
+        assert len(all_keys) == 1, "Last stage only has one module - LLM"
+        main_module_key = all_keys[0]
+        output_tensor = {main_module_key: output_tensor}
+
+    # Handle output_tensor_grad: None (last stage) or dict (intermediate stages)
+    if not output_tensor_grad:
+        # Last stage: no gradient from next stage
+        output_tensor_grad = {key: None for key in output_tensor.keys()}
+
+    # Apply grad scaling if needed (for last stage)
+    for module_name in output_tensor.keys():
+        if output_tensor_grad[module_name] is None and config.grad_scale_func is not None:
+            output_tensor[module_name] = config.grad_scale_func(output_tensor[module_name])
+
+    # Perform backward pass for each module
+    for module_name in output_tensor.keys():
+        output_tensor_module = output_tensor[module_name]
+        output_tensor_grad_module = output_tensor_grad[module_name]
+
+        # Skip backward if tensor doesn't require gradients
+        if output_tensor_module is not None and output_tensor_module.requires_grad:
+            if config.deallocate_pipeline_outputs:
+                custom_backward(output_tensor_module, output_tensor_grad_module)
+            else:
+                torch.autograd.backward(
+                    output_tensor_module, grad_tensors=output_tensor_grad_module
+                )
+
+    # Collect gradients for input tensors
+    input_tensor_grad = {}
+    for module_name, tensor in input_tensor.items():
+        if isinstance(tensor, list):
+            tensor = tensor[0]
+        if tensor is None:
+            input_tensor_grad[module_name] = None
+        else:
+            input_tensor_grad[module_name] = tensor.grad
+
+    if config.timers is not None:
+        config.timers('backward-compute').stop()
+
+    return input_tensor_grad
+
+
+def backward_step(
+    input_tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    output_tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    output_tensor_grad: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    model_type: str,
+    config: Any,
+):
+    """Backward step wrapper supporting both tensor and dictionary formats.
+
+    The inputs and outputs are dictionaries in the multi-module case.
+    The keys of the dictionaries are the module names.
+
+    Returns:
+        The input tensor or dictionary of tensors.
+ """ + + if isinstance(input_tensor, dict): + return _backward_step_dict( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + else: + return backward_step_tensor( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + def check_first_val_step(first_val_step, forward_only, cond): """Check if it is the first validation step.""" if (first_val_step is not None) and forward_only: @@ -1935,8 +2044,8 @@ def get_tensor_shapes( micro_batch_size: int, decoder_seq_length: int, config, - tp_group: torch.distributed.ProcessGroup, - cp_group: torch.distributed.ProcessGroup, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + cp_group: Optional[torch.distributed.ProcessGroup] = None, ): """ Determine right tensor sizes (based on position of rank with respect to split rank) and @@ -1944,15 +2053,23 @@ def get_tensor_shapes( """ tensor_shapes = [] - # Use decoder_seq_length if provided, otherwise use seq_length - effective_seq_length = decoder_seq_length if decoder_seq_length is not None else seq_length - effective_seq_length = effective_seq_length // cp_group.size() + if config.variable_seq_lengths: + # with variable seq_lengths, ranks exchange the tensor shape with each other + tensor_shapes.append(()) + return tensor_shapes + else: + assert ( + tp_group is not None and cp_group is not None + ), "tp_group and cp_group must be provided" + # Use decoder_seq_length if provided, otherwise use seq_length + effective_seq_length = decoder_seq_length if decoder_seq_length is not None else seq_length + effective_seq_length = effective_seq_length // cp_group.size() - if config.sequence_parallel: - effective_seq_length = effective_seq_length // tp_group.size() + if config.sequence_parallel: + effective_seq_length = effective_seq_length // tp_group.size() - tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size)) - return tensor_shapes + tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size)) + return tensor_shapes def forward_backward_pipelining_without_interleaving( @@ -1968,8 +2085,8 @@ def forward_backward_pipelining_without_interleaving( collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None, adjust_tensor_shapes_fn: Optional[Callable] = None, - p2p_communicator: Optional[P2PCommunicator] = None, - pg_collection: Optional[ProcessGroupCollection] = None, + p2p_communicator: Optional[Union[P2PCommunicator, MultiModulePipelineCommunicator]] = None, + pg_collection: Optional[Union[ProcessGroupCollection, List[ProcessGroupCollection]]] = None, ): """Run non-interleaved 1F1B schedule, with communication between pipeline stages. 
Returns dictionary with losses if the last stage, empty dict otherwise.""" @@ -1990,7 +2107,7 @@ def forward_backward_pipelining_without_interleaving( raise ValueError( "Non-interleaved pipeline parallelism does not support overlapping p2p communication" ) - + tp_group, cp_group, llm_cp_size = None, None, None if p2p_communicator is None and pg_collection is None: p2p_communicator = P2PCommunicator( pp_group=parallel_state.get_pipeline_model_parallel_group(), config=config @@ -2010,33 +2127,24 @@ def forward_backward_pipelining_without_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + llm_cp_size = cp_group.size() elif p2p_communicator is not None and pg_collection is not None: - model_type = get_model_type(model) - assert model_type != ModelType.encoder_and_decoder, ( - "encoder PP stages not yet supported when passing custom process groups. " - "support coming soon!" - ) - assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" - assert hasattr(pg_collection, 'tp'), "pg_collection must have tp_group" - assert hasattr(pg_collection, 'cp'), "pg_collection must have cp_group" - assert hasattr(pg_collection, 'embd'), ( - "pg_collection must have a embd. In previous version, it is used default " - "`parallel_state.default_embedding_ranks` to create the process group. " - " If you are using the default process group, please use " - " `parallel_state.get_embedding_group()` " - "If you don't need embd_group, you need to explicitly set it to None." - ) - assert hasattr(pg_collection, 'pos_embd'), ( - "pg_collection must have a pos_embd. In previous version, it is used default " - "`parallel_state.default_position_embedding_ranks` to create the process group. " - " If you are using the default process group, please use " - " `parallel_state.get_position_embedding_group()` " - "If you don't need pos_embd_group, you need to explicitly set it to None." - ) - assert hasattr(pg_collection, 'pp'), "pg_collection must have pp_group" - assert hasattr(pg_collection, 'dp_cp'), "pg_collection must have dp_cp_group" - tp_group = pg_collection.tp - cp_group = pg_collection.cp + if isinstance(pg_collection, list): + # cases when multiple modules are colocated + assert ( + config.variable_seq_lengths + ), "variable seq_lengths is required when multiple modules are colocated" + # when llm is colocated for now assume last collection in the list is the llm + # TODO: ykarnati: Have a better interface to handle this + # (without breaking backward compatibility) + assert hasattr(pg_collection[-1], 'cp'), "pg_collection must have cp_group" + llm_cp_size = pg_collection[-1].cp.size() + else: + assert hasattr(pg_collection, 'tp'), "pg_collection must have tp_group" + assert hasattr(pg_collection, 'cp'), "pg_collection must have cp_group" + tp_group = pg_collection.tp + cp_group = pg_collection.cp + llm_cp_size = pg_collection.cp.size() else: raise ValueError( "Invalid combination of p2p_communicator, pg_collection " @@ -2046,7 +2154,7 @@ def forward_backward_pipelining_without_interleaving( # Needed only when gradients are finalized in M-Core if config.finalize_model_grads_func is not None and not forward_only: embedding_module = clear_embedding_activation_buffer( - config, model, is_pp_last_stage(p2p_communicator.pp_group) + config, model, p2p_communicator.is_pp_last_stage ) if config.timers is not None: @@ -2078,9 +2186,7 @@ def enable_grad_sync(): disable_grad_sync() # Compute number of warmup microbatches. 
- num_warmup_microbatches = ( - p2p_communicator.pp_group.size() - p2p_communicator.pp_group.rank() - 1 - ) + num_warmup_microbatches = p2p_communicator.num_warmup_microbatches num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) num_microbatches_remaining = num_microbatches - num_warmup_microbatches @@ -2098,7 +2204,6 @@ def enable_grad_sync(): model_type = get_model_type(model) - rank = p2p_communicator.pp_group.rank() recv_tensor_shapes = get_tensor_shapes( seq_length=seq_length, micro_batch_size=micro_batch_size, @@ -2142,7 +2247,7 @@ def enable_grad_sync(): checkpoint_activations_microbatch = None input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) output_tensor, num_tokens = forward_step( forward_step_func, @@ -2157,22 +2262,22 @@ def enable_grad_sync(): checkpoint_activations_microbatch=checkpoint_activations_microbatch, is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0), current_microbatch=i, - is_last_stage=is_pp_last_stage(p2p_communicator.pp_group), + is_last_stage=p2p_communicator.is_pp_last_stage, ) - p2p_communicator.send_forward(output_tensor, is_pp_last_stage(p2p_communicator.pp_group)) + p2p_communicator.send_forward(output_tensor, p2p_communicator.is_pp_last_stage) total_num_tokens += num_tokens if not forward_only: input_tensors.append(input_tensor) output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) # Before running 1F1B, need to receive first forward tensor. # If all microbatches are run in warmup / cooldown phase, then no need to # receive this tensor here. if num_microbatches_remaining > 0: input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) # Run 1F1B in steady state. @@ -2195,34 +2300,32 @@ def enable_grad_sync(): input_tensor, forward_data_store, config, - cp_group_size=pg_collection.cp.size(), + cp_group_size=llm_cp_size, collect_non_loss_data=collect_non_loss_data, checkpoint_activations_microbatch=checkpoint_activations_microbatch, is_first_microbatch=check_first_val_step( first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0) ), current_microbatch=i + num_warmup_microbatches, - is_last_stage=is_pp_last_stage(p2p_communicator.pp_group), + is_last_stage=p2p_communicator.is_pp_last_stage, ) total_num_tokens += num_tokens if forward_only: - p2p_communicator.send_forward( - output_tensor, is_pp_last_stage(p2p_communicator.pp_group) - ) + p2p_communicator.send_forward(output_tensor, p2p_communicator.is_pp_last_stage) if not last_iteration: input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) else: output_tensor_grad = p2p_communicator.send_forward_recv_backward( - output_tensor, send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group) + output_tensor, send_tensor_shapes, p2p_communicator.is_pp_last_stage ) # Add input_tensor and output_tensor to end of list. 
input_tensors.append(input_tensor) output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) # Pop input_tensor and output_tensor from the start of the list for # the backward pass. @@ -2232,7 +2335,7 @@ def enable_grad_sync(): # Enable grad sync for the last microbatch in the batch if the full # backward pass completes in the 1F1B stage. if num_warmup_microbatches == 0 and last_iteration: - if config.grad_sync_func is None or rank == 0: + if config.grad_sync_func is None or p2p_communicator.is_pp_first_stage: enable_grad_sync() input_tensor_grad = backward_step( @@ -2242,13 +2345,11 @@ def enable_grad_sync(): if last_iteration: input_tensor = None p2p_communicator.send_backward( - input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group) + input_tensor_grad, p2p_communicator.is_pp_first_stage ) else: input_tensor = p2p_communicator.send_backward_recv_forward( - input_tensor_grad, - recv_tensor_shapes, - is_pp_first_stage(p2p_communicator.pp_group), + input_tensor_grad, recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) # Run cooldown backward passes. @@ -2261,23 +2362,21 @@ def enable_grad_sync(): # pipeline stages do grad reduction during pipeline # bubble. if i == num_warmup_microbatches - 1: - if config.grad_sync_func is None or rank == 0: + if config.grad_sync_func is None or p2p_communicator.is_pp_first_stage: enable_grad_sync() input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) output_tensor_grad = p2p_communicator.recv_backward( - send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group) + send_tensor_shapes, p2p_communicator.is_pp_last_stage ) input_tensor_grad = backward_step( input_tensor, output_tensor, output_tensor_grad, model_type, config ) - p2p_communicator.send_backward( - input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group) - ) + p2p_communicator.send_backward(input_tensor_grad, p2p_communicator.is_pp_first_stage) # Launch any remaining grad reductions. if no_sync_context is not None: @@ -2290,7 +2389,7 @@ def enable_grad_sync(): # If defer_embedding_wgrad_compute is enabled we need to do the # weight gradient GEMM's here. finish_embedding_wgrad_compute( - config, embedding_module, is_pp_last_stage(p2p_communicator.pp_group), tp_group + config, embedding_module, p2p_communicator.is_pp_last_stage, tp_group ) # Finalize model grads (perform full grad all-reduce / reduce-scatter for diff --git a/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py b/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py index 4b426b718e..eec2b3160f 100644 --- a/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py +++ b/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
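Note (reviewer annotation, not part of the patch) on the get_tensor_shapes rework in schedules.py above: with config.variable_seq_lengths the schedule now appends an empty shape and lets ranks exchange the real shapes at communication time; otherwise the (s, b, h) shape is derived from the sequence length divided by the CP size and, under sequence parallelism, the TP size. A small worked sketch of that arithmetic (the function name below is a local stand-in):

    def expected_tensor_shape(seq_length, micro_batch_size, hidden_size,
                              cp_size=1, tp_size=1, sequence_parallel=False,
                              variable_seq_lengths=False):
        if variable_seq_lengths:
            # Shapes are negotiated by the communicator at send/recv time.
            return ()
        s = seq_length // cp_size
        if sequence_parallel:
            s //= tp_size
        return (s, micro_batch_size, hidden_size)

    # seq 4096, micro-batch 2, hidden 1024, CP=2, TP=4 with sequence parallelism:
    print(expected_tensor_shape(4096, 2, 1024, cp_size=2, tp_size=4,
                                sequence_parallel=True))     # (512, 2, 1024)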
+ import logging import os import sys @@ -7,18 +9,12 @@ import torch.distributed as dist from packaging import version -from megatron.core import parallel_state +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.model_parallel_config import ModelParallelConfig from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.parallel_state import ( - get_context_parallel_group, - get_expert_model_parallel_rank, - get_tensor_model_parallel_rank, -) +from megatron.core.parallel_state import get_context_parallel_group, get_tensor_model_parallel_rank from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig @@ -108,19 +104,15 @@ def _shard_and_copy_( ) -def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): +def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1, ep=1, etp=1): """Create a HyperCommGrid with tensor parallelism=2, context parallelism=2, and data parallelism=2.""" # Set up environment for world size 8 if not already set - if not dist.is_initialized(): - raise RuntimeError("Distributed process group is not initialized") - - # tests below assume a world size of 8 if "WORLD_SIZE" not in os.environ: os.environ["WORLD_SIZE"] = "8" grid = HyperCommGrid( - shape=[tp, cp, pp, dp], - dim_names=["tp", "cp", "pp", "dp"], + shape=[tp, cp, pp, dp, ep, etp], + dim_names=["tp", "cp", "pp", "dp", "ep", "etp"], rank_offset=offset, backend="nccl", ) @@ -128,6 +120,12 @@ def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): _ = grid.create_pg(["cp"]) _ = grid.create_pg(["pp"]) _ = grid.create_pg(["dp"]) + _ = grid.create_pg(["ep"]) + _ = grid.create_pg(["tp", "pp"]) + _ = grid.create_pg(["dp", "cp"]) + _ = grid.create_pg(["tp", "cp"]) + _ = grid.create_pg(["tp", "dp", "cp"]) + _ = grid.create_pg(["tp", "ep", "pp"]) return grid @@ -136,6 +134,16 @@ def _get_pg_collection_from_grid(grid): pg_collection.tp = grid.get_pg("tp") pg_collection.cp = grid.get_pg("cp") pg_collection.pp = grid.get_pg("pp") + pg_collection.ep = grid.get_pg("ep") + dp_group = grid.get_pg("dp") + dp_cp_group = grid.get_pg(["dp", "cp"]) + pg_collection.dp = dp_group + pg_collection.dp_cp = dp_cp_group + pg_collection.mp = grid.get_pg(["tp", "pp"]) + pg_collection.dp_cp = grid.get_pg(["dp", "cp"]) + pg_collection.tp_cp = grid.get_pg(["tp", "cp"]) + pg_collection.tp_dp_cp = grid.get_pg(["tp", "dp", "cp"]) + pg_collection.tp_ep_pp = grid.get_pg(["tp", "ep", "pp"]) return pg_collection @@ -147,7 +155,7 @@ def _avg_params(module: torch.nn.Module, group: dist.ProcessGroup = None) -> Non def get_transformer_block_and_grid( - ref_block, + ref_block=None, tp_size=1, cp_size=1, pp_size=1, @@ -156,13 +164,15 @@ def get_transformer_block_and_grid( use_global_parallel_state: bool = False, hidden_size: int = 4096, dtype: torch.dtype = torch.bfloat16, + wrap_with_ddp: bool = False, ): """Utility to build a ``TransformerBlock`` for tests.""" current_rank = dist.get_rank() if use_global_parallel_state: block = _create_transformer_block(dtype=dtype, 
hidden_size=hidden_size) - _shard_and_copy_(ref_block, block, tp_size, get_tensor_model_parallel_rank()) + if ref_block is not None: + _shard_and_copy_(ref_block, block, tp_size, get_tensor_model_parallel_rank()) grid = None else: grid = create_hypercomm_grid( @@ -173,10 +183,20 @@ def get_transformer_block_and_grid( block = _create_transformer_block( dtype=dtype, hidden_size=hidden_size, pg_collection=pg_collection ) - _shard_and_copy_(ref_block, block, tp_size, pg_collection.tp.rank()) + if ref_block is not None: + _shard_and_copy_(ref_block, block, tp_size, pg_collection.tp.rank()) else: block = None + if wrap_with_ddp and block is not None: + ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000) + block = DistributedDataParallel( + config=block.config, ddp_config=ddp_config, module=block, pg_collection=pg_collection + ) + block.pre_process = False + block.post_process = False + block.share_embeddings_and_output_weights = False + return block, grid diff --git a/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py b/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py index 73739859f4..f1b00f9e0b 100644 --- a/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py +++ b/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py @@ -1,14 +1,11 @@ -import logging +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import os -import sys import pytest import torch import torch.distributed as dist from packaging import version -from megatron.core import parallel_state -from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.model_parallel_config import ModelParallelConfig from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from tests.unit_tests.pipeline_parallel.test_bridge_communicator import ( diff --git a/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py new file mode 100644 index 0000000000..b801a353cf --- /dev/null +++ b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py @@ -0,0 +1,366 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +import logging +import os +from contextlib import contextmanager +from typing import Dict, List + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core import ModelParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from tests.unit_tests.pipeline_parallel.test_bridge_communicator import ( + _get_pg_collection_from_grid, + get_transformer_block_and_grid, +) +from tests.unit_tests.pipeline_parallel.test_schedules import ( + _populate_embedding_and_position_groups, +) +from tests.unit_tests.test_utilities import Utils + +rank = Utils.rank +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') + + +class DataIterator: + + def __init__(self, hidden_size: int, seq_length: int, micro_batch_size: int): + self.hidden_size = hidden_size + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + + def __iter__(self): + return self + + def __next__(self): + return torch.randn( + self.seq_length, + self.micro_batch_size, + self.hidden_size, + device='cuda', + dtype=torch.bfloat16, + ) + + +class SingleEncoderModel(torch.nn.Module): + def __init__( + self, + hidden_size, + encoder_tp, + encoder_pp, + encoder_dp, + llm_tp, + llm_pp, + llm_dp, + llm_grid_offset, + ): + + super().__init__() + + self.encoder, self.encoder_grid = get_transformer_block_and_grid( + tp_size=encoder_tp, + cp_size=1, + pp_size=encoder_pp, + dp_size=encoder_dp, + hidden_size=hidden_size, + wrap_with_ddp=True, + ) + + self.llm, self.llm_grid = get_transformer_block_and_grid( + tp_size=llm_tp, + cp_size=1, + pp_size=llm_pp, + dp_size=llm_dp, + grid_offset=llm_grid_offset, + hidden_size=hidden_size, + wrap_with_ddp=True, + ) + + # Simple list for iteration + self.modules_and_grids = [(self.encoder, self.encoder_grid), (self.llm, self.llm_grid)] + + self.current_rank = dist.get_rank() + self.encoder_input_tensor = None + self.llm_input_tensor = None + + def finish_grad_sync(self): + """Finish gradient synchronization for all active modules on this rank.""" + for module, grid in self.modules_and_grids: + if module is not None and self.is_current_rank_in_grid(grid): + module.finish_grad_sync() + + @contextmanager + def no_sync(self): + contexts = [] + if self.is_current_rank_in_grid(self.encoder_grid): + contexts.append(self.encoder.no_sync()) + if self.is_current_rank_in_grid(self.llm_grid): + contexts.append(self.llm.no_sync()) + + # Enter all contexts + for ctx in contexts: + ctx.__enter__() + + try: + yield + finally: + # Exit all contexts in reverse order + for ctx in reversed(contexts): + ctx.__exit__(None, None, None) + + @property + def ddp_config(self): + # Try to get ddp_config from the first available module on this rank + for module, grid in self.modules_and_grids: + if module is not None and self.is_current_rank_in_grid(grid): + return module.ddp_config + raise AttributeError(f"No active modules with ddp_config found on rank {self.current_rank}") + + def scale_gradients(self, scaling_factor: float): + """Scale gradients for all active modules on this rank.""" + for module, grid in 
self.modules_and_grids: + if module is not None and self.is_current_rank_in_grid(grid): + module.scale_gradients(scaling_factor) + + def is_current_rank_in_grid(self, grid: HyperCommGrid) -> bool: + """Check if the current rank is in the grid.""" + return grid.rank_offset <= self.current_rank < (grid.rank_offset + grid.size) + + def finalize_model_grads(self, module=None, num_tokens=None, pg_collection=None): + for module, grid in self.modules_and_grids: + if module is not None and self.is_current_rank_in_grid(grid): + finalize_model_grads( + [module], + num_tokens=None, + pg_collection=_get_pg_collection_with_embedding_groups(grid), + ) + + @contextmanager + def no_sync(self): + contexts = [] + for module, grid in self.modules_and_grids: + if module is not None and self.is_current_rank_in_grid(grid): + contexts.append(module.no_sync()) + + # Enter all contexts + for ctx in contexts: + ctx.__enter__() + + try: + yield + finally: + # Exit all contexts in reverse order + for ctx in reversed(contexts): + ctx.__exit__(None, None, None) + + def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]): + if self.is_current_rank_in_grid(self.encoder_grid) and 'encoder' in input_tensor[0]: + if isinstance(input_tensor[0]["encoder"], list): + encoder_input_tensor = input_tensor[0]["encoder"][0] + else: + encoder_input_tensor = input_tensor[0]["encoder"] + logging.debug( + f"[Rank {dist.get_rank()} ][SingleEncoderModel] [set_input_tensor] [encoder] input tensor shape: {input_tensor[0]['encoder'][0].shape}" + ) + self.encoder_input_tensor = encoder_input_tensor + elif self.is_current_rank_in_grid(self.llm_grid): + if 'llm' in input_tensor[0]: + if isinstance(input_tensor[0]["llm"], list): + llm_input_tensor = input_tensor[0]["llm"][0] + else: + llm_input_tensor = input_tensor[0]["llm"] + logging.debug( + f"[Rank {dist.get_rank()} ][SingleEncoderModel] [set_input_tensor] [llm] input tensor shape: {llm_input_tensor.shape}" + ) + self.llm_input_tensor = llm_input_tensor + elif 'encoder' in input_tensor[0]: + logging.debug( + f"[Rank {dist.get_rank()} ][SingleEncoderModel] [set_input_tensor] [encoder] input tensor shape: {input_tensor[0]['encoder'].shape}" + ) + self.llm_input_tensor = input_tensor[0]["encoder"] + else: + raise ValueError(f"Rank {dist.get_rank()} is not valid") + + def forward(self, hidden_states): + + current_rank = dist.get_rank() + output_dict = {} + if self.is_current_rank_in_grid(self.encoder_grid): + # if pp rank > 0 in encoder pp group then we use self.encoder_input_tensor as input else we use hidden_states + if is_pp_first_stage(self.encoder_grid.get_pg("pp")): + input_tensor = hidden_states + else: + assert ( + self.encoder_input_tensor is not None + ), "Encoder input tensor is not provided for pp rank > 0" + input_tensor = self.encoder_input_tensor + logging.debug( + f"[Rank {dist.get_rank()} ][SingleEncoderModel] [forward] [encoder] input tensor shape: {input_tensor.shape}" + ) + output_dict["encoder"] = self.encoder(input_tensor, attention_mask=None) + elif self.is_current_rank_in_grid(self.llm_grid): + assert ( + self.llm_input_tensor is not None + ), "LLM input tensor is not provided for pp rank > 0" + input_tensor = self.llm_input_tensor + logging.debug( + f"[Rank {dist.get_rank()} ][SingleEncoderModel] [forward] [llm] input tensor shape: {input_tensor.shape}" + ) + output_dict["llm"] = self.llm(input_tensor, attention_mask=None) + else: + raise ValueError(f"Rank {current_rank} is not valid") + + return output_dict + + +def 
_get_pg_collection_with_embedding_groups(grid): + pg_collection = _get_pg_collection_from_grid(grid) + if pg_collection.pp: + pos_embd_pg, embd_pg = _populate_embedding_and_position_groups(pg_collection.pp) + pos_embd_pg = pos_embd_pg if is_pp_first_stage(pg_collection.pp) else None + embd_pg = ( + embd_pg + if (is_pp_last_stage(pg_collection.pp) or is_pp_first_stage(pg_collection.pp)) + else None + ) + pg_collection.pos_embd = pos_embd_pg + pg_collection.embd = embd_pg + + return pg_collection + + +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Device mesh feature requires PyTorch 2.3 or later", +) +@pytest.mark.parametrize( + "encoder_tp,encoder_pp,encoder_dp,llm_tp,llm_pp,llm_dp,llm_grid_offset", + [ + (2, 2, 1, 2, 2, 1, 4), + (4, 1, 1, 2, 2, 1, 4), + (2, 1, 1, 1, 6, 1, 2), + (2, 2, 1, 1, 4, 1, 4), + (2, 1, 2, 1, 1, 4, 4), + (2, 1, 2, 2, 2, 1, 4), + ], +) +def test_forward_backward_pipelining_without_interleaving_multi_module_single_encoder( + encoder_tp, encoder_pp, encoder_dp, llm_tp, llm_pp, llm_dp, llm_grid_offset +): + + Utils.initialize_distributed() + + def step_func(data_iterator, model): + + def loss_func(output_tensor_dict: Dict[str, torch.Tensor]): + assert ( + 'llm' in output_tensor_dict + ), f'llm is not in output_tensor_dict: {output_tensor_dict}' + loss = output_tensor_dict['llm'].sum() + return loss, {'loss_reduced': loss} + + if data_iterator is not None: + input_tensor = next(data_iterator) + else: + input_tensor = None + + model_output = model(input_tensor) + + return model_output, loss_func + + sequence_length = 512 + micro_batch_size = 4 + hidden_size = 1024 + + # Create model + model = SingleEncoderModel( + hidden_size=hidden_size, + encoder_tp=encoder_tp, + encoder_pp=encoder_pp, + encoder_dp=encoder_dp, + llm_tp=llm_tp, + llm_pp=llm_pp, + llm_dp=llm_dp, + llm_grid_offset=llm_grid_offset, + ) + model.model_type = 'unit-test' + + module_to_grid_map = {'encoder': model.encoder_grid, 'llm': model.llm_grid} + topology = { + 'encoder': ['llm'], # image_encoder sends forward results to llm + 'llm': [], # llm is the last stage here + } + config = ModelParallelConfig(pipeline_dtype=torch.bfloat16) + config.calculate_per_token_loss = False + config.qk_layernorm = False + config.sequence_parallel = False + config.moe_router_enable_expert_bias = False + config.moe_router_load_balancing_type = "aux_loss" + config.variable_seq_lengths = True + config.batch_p2p_comm = False + config.no_sync_func = model.no_sync + config.finalize_model_grads_func = model.finalize_model_grads + config.fine_grained_activation_offloading = False + + # Add grad scale function to convert float losses to tensors + def grad_scale_func(loss): + if isinstance(loss, (int, float)): + return torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + else: + return loss # Already a tensor + + config.grad_scale_func = grad_scale_func + model.config = config + config.hidden_size = hidden_size + + multimodule_communicator = MultiModulePipelineCommunicator( + module_to_grid_map, topology, config, dim_mapping={'s': 0, 'h': 2, 'b': 1} + ) + + data_iterator = None + if model.is_current_rank_in_grid(model.encoder_grid) and is_pp_first_stage( + model.encoder_grid.get_pg("pp") + ): + data_iterator = DataIterator( + hidden_size=hidden_size, seq_length=sequence_length, micro_batch_size=micro_batch_size + ) + + common_args = { + 'forward_step_func': step_func, + 'data_iterator': data_iterator, + 'model': [model], + 'num_microbatches': 16, + 'seq_length': 
sequence_length, + 'micro_batch_size': micro_batch_size, + 'forward_only': False, + } + + if 0 <= dist.get_rank() < llm_grid_offset: + pg_collection = _get_pg_collection_with_embedding_groups(model.encoder_grid) + elif llm_grid_offset <= dist.get_rank() < llm_grid_offset + model.llm_grid.size: + pg_collection = _get_pg_collection_with_embedding_groups(model.llm_grid) + else: + raise ValueError(f"Rank {dist.get_rank()} is not valid") + + losses_reduced_explicit = schedule.forward_backward_pipelining_without_interleaving( + p2p_communicator=multimodule_communicator, pg_collection=pg_collection, **common_args + ) + logging.info(f"Losses reduced explicit: {losses_reduced_explicit}") + + +if __name__ == "__main__": + + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') + + # Use the same parameters as defined in the pytest.mark.parametrize decorator + test_forward_backward_pipelining_without_interleaving_multi_module_single_encoder( + encoder_tp=2, encoder_pp=1, encoder_dp=2, llm_tp=1, llm_pp=1, llm_dp=4, llm_grid_offset=4 + )
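Note (reviewer annotation, not part of the patch): the dict-aware backward path added in schedules.py and exercised by the test above boils down to per-module autograd calls keyed by module name. A torch-only sketch of that flow for the last stage (a single 'llm' module, scalar loss, no incoming gradient); tensor names and shapes here are illustrative only:

    import torch

    # Activation received from the previous stage; grads on it must be kept,
    # mirroring the patch's retain_grad() on per-module inputs.
    received = torch.randn(8, 2, 16, requires_grad=True)
    input_tensors = {"llm": received}

    # Forward through the stage and reduce to a scalar loss (last stage).
    output_tensors = {"llm": (input_tensors["llm"] * 3.0).sum()}
    output_tensor_grads = {"llm": None}      # no gradient arrives from a next stage

    input_tensor_grads = {}
    for name, out in output_tensors.items():
        input_tensors[name].retain_grad()
        torch.autograd.backward(out, grad_tensors=output_tensor_grads[name])
        input_tensor_grads[name] = input_tensors[name].grad

    print(input_tensor_grads["llm"].shape)   # torch.Size([8, 2, 16])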