8 changes: 5 additions & 3 deletions megatron/core/pipeline_parallel/bridge_communicator.py
@@ -494,7 +494,7 @@ def recv_backward(self) -> torch.Tensor:
received_gradients_list.append(grad_tensor)

# Concatenate received gradients
aggregated_gradient = torch.cat(received_gradients_list, dim=0)
aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b'])
logging.debug(
f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} "
f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}"
@@ -615,7 +615,7 @@ def send_forward_recv_backward(
req.wait()

# Concatenate received gradients
aggregated_gradient = torch.cat(received_gradients_list, dim=0)
aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b'])
logging.debug(
f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} "
f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}"
@@ -737,7 +737,9 @@ def send_backward_recv_forward(
req.wait()

# Concatenate received activations
aggregated_activation = torch.cat(received_activations_list, dim=0)
aggregated_activation = torch.cat(
received_activations_list, dim=self.dim_mapping['b']
)
logging.debug(
f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} "
f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}"
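The three hunks above all make the same change: received chunks are now concatenated along the batch axis named by `self.dim_mapping['b']` instead of a hard-coded `dim=0`. A minimal standalone sketch of why that matters, assuming the common Megatron `[sequence, batch, hidden]` layout (the literal mapping values here are illustrative, not taken from `BridgeCommunicator`):

```python
import torch

# Illustrative mapping only; the real BridgeCommunicator derives this internally.
dim_mapping = {'s': 0, 'b': 1, 'h': 2}

# Two gradient/activation chunks received from different source ranks.
chunk_a = torch.randn(128, 2, 1024)
chunk_b = torch.randn(128, 2, 1024)

# Old behaviour: dim=0 grows the sequence axis -> [256, 2, 1024] (wrong axis).
wrong = torch.cat([chunk_a, chunk_b], dim=0)

# New behaviour: concatenate along the batch axis -> [128, 4, 1024].
aggregated = torch.cat([chunk_a, chunk_b], dim=dim_mapping['b'])
print(wrong.shape, aggregated.shape)
```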
83 changes: 53 additions & 30 deletions megatron/core/pipeline_parallel/multimodule_communicator.py
@@ -16,6 +16,32 @@
Shape = Union[List[int], torch.Size]


def _ensure_3d_tensor(tensor):
"""Ensure tensor is 3D for P2P/bridge communication.

P2P and bridge communicators expect 3D tensors.
Handles both single tensors and lists of tensors (for VPP).
"""
if isinstance(tensor, list):
return [_ensure_3d_tensor(t) for t in tensor]
if isinstance(tensor, torch.Tensor) and tensor.ndim == 2:
return tensor.unsqueeze(-1)
return tensor
Review comment: Are we sure that the tensor must be 2D or 3D? Should we catch some exception here?
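One way to address this question would be to validate the rank explicitly. A hypothetical stricter variant, not part of this PR, that fails fast on anything other than 2D or 3D tensor input:

```python
import torch

def _ensure_3d_tensor_strict(tensor):
    """Like _ensure_3d_tensor, but reject unexpected ranks (hypothetical sketch)."""
    if isinstance(tensor, list):
        return [_ensure_3d_tensor_strict(t) for t in tensor]
    if isinstance(tensor, torch.Tensor):
        if tensor.ndim == 2:
            return tensor.unsqueeze(-1)
        if tensor.ndim == 3:
            return tensor
        raise ValueError(
            f"Expected a 2D or 3D tensor for P2P/bridge communication, got ndim={tensor.ndim}"
        )
    return tensor
```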



def _restore_tensor_shape(tensor):
"""Restore original tensor shape after P2P/bridge communication.

Remove the extra dimension added by _ensure_3d_tensor if it was singleton.
Handles both single tensors and lists of tensors (for VPP).
"""
if isinstance(tensor, list):
return [_restore_tensor_shape(t) for t in tensor]
if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1:
return tensor.squeeze(-1)
return tensor


@dataclass
class RankModuleInfo:
"""Information about a rank in a module."""
@@ -281,12 +307,12 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool
# If last stage, and has outgoing modules, send forward activation
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
bridge_comm.send_forward(output_dict[module_name])
tensor_to_send = _ensure_3d_tensor(output_dict[module_name])
bridge_comm.send_forward(tensor_to_send)
else:
# If not last stage, send forward activation by using P2P communicator.
rank_module_info.p2p_communicator.send_forward(
output_dict[module_name], is_last_stage=False
)
tensor_to_send = _ensure_3d_tensor(output_dict[module_name])
rank_module_info.p2p_communicator.send_forward(tensor_to_send, is_last_stage=False)

def send_forward_recv_backward(
self,
@@ -303,28 +329,23 @@ def send_forward_recv_backward(
Returns:
A dictionary mapping module names to tensors.
"""
logging.debug(
f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] "
f"[send_forward_recv_backward] output_dict keys: {output_dict.keys()}, "
f"tensor_shape: {tensor_shape}, is_last_stage: {is_last_stage}"
)
grad_dict = {}
for module_name, rank_module_info in self.rank_module_map.items():
if rank_module_info.pp_rank == rank_module_info.pp_size - 1:
# If last stage, and has outgoing modules, send forward activation and
# receive backward gradient by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
grad_dict[bridge_comm.src_module_name] = bridge_comm.send_forward_recv_backward(
output_dict[module_name]
)
tensor_to_send = _ensure_3d_tensor(output_dict[module_name])
grad_tensor = bridge_comm.send_forward_recv_backward(tensor_to_send)
grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad_tensor)
else:
# If not last stage, send forward activation and receive backward gradient
# by using P2P communicator.
grad_dict[module_name] = (
rank_module_info.p2p_communicator.send_forward_recv_backward(
output_dict[module_name], tensor_shapes=tensor_shape, is_last_stage=False
)
tensor_to_send = _ensure_3d_tensor(output_dict[module_name])
grad_tensor = rank_module_info.p2p_communicator.send_forward_recv_backward(
tensor_to_send, tensor_shapes=tensor_shape, is_last_stage=False
)
grad_dict[module_name] = _restore_tensor_shape(grad_tensor)
return grad_dict

def send_backward_recv_forward(
@@ -353,19 +374,19 @@ def send_backward_recv_forward(
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
# If first stage, and has incoming modules, send backward gradient and
# receive forward activation by using bridge communicator.
input_dict[bridge_comm.src_module_name] = (
bridge_comm.send_backward_recv_forward(
grad_dict[bridge_comm.src_module_name]
)
grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name])
activation_tensor = bridge_comm.send_backward_recv_forward(grad_to_send)
input_dict[bridge_comm.src_module_name] = _restore_tensor_shape(
activation_tensor
)
else:
# If not first stage, send backward gradient and receive forward activation
# by using P2P communicator.
input_dict[module_name] = (
rank_module_info.p2p_communicator.send_backward_recv_forward(
grad_dict[module_name], tensor_shapes=tensor_shape, is_first_stage=False
)
grad_to_send = _ensure_3d_tensor(grad_dict[module_name])
activation_tensor = rank_module_info.p2p_communicator.send_backward_recv_forward(
grad_to_send, tensor_shapes=tensor_shape, is_first_stage=False
)
input_dict[module_name] = _restore_tensor_shape(activation_tensor)
return input_dict

def recv_backward(
@@ -389,12 +410,14 @@ def recv_backward(
# If last stage, and has outgoing modules, receive backward gradient
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
grad_dict[bridge_comm.src_module_name] = bridge_comm.recv_backward()
recv_grad_tensor = bridge_comm.recv_backward()
grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(recv_grad_tensor)
else:
# If not last stage, receive backward gradient by using P2P communicator.
grad_dict[module_name] = rank_module_info.p2p_communicator.recv_backward(
recv_grad_tensor = rank_module_info.p2p_communicator.recv_backward(
tensor_shapes=tensor_shape, is_last_stage=False
)
grad_dict[module_name] = _restore_tensor_shape(recv_grad_tensor)
return grad_dict

def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool = False):
@@ -412,12 +435,12 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool
# If first stage, and has incoming modules, send backward gradient
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
bridge_comm.send_backward(grad_dict[bridge_comm.src_module_name])
grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name])
bridge_comm.send_backward(grad_to_send)
else:
# If not first stage, send backward gradient by using P2P communicator.
rank_module_info.p2p_communicator.send_backward(
grad_dict[module_name], is_first_stage=False
)
grad_to_send = _ensure_3d_tensor(grad_dict[module_name])
rank_module_info.p2p_communicator.send_backward(grad_to_send, is_first_stage=False)

@staticmethod
def compute_total_pipeline_stages(
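The pattern repeated throughout this file is a round trip: `_ensure_3d_tensor` pads 2D tensors before handing them to the P2P or bridge communicator, and `_restore_tensor_shape` squeezes the padding back out of whatever comes back. A self-contained sketch of that round trip, with the helpers reproduced from the diff above so it runs standalone:

```python
import torch

# Helpers copied from the diff above, reproduced here so the sketch runs on its own.
def _ensure_3d_tensor(tensor):
    if isinstance(tensor, list):
        return [_ensure_3d_tensor(t) for t in tensor]
    if isinstance(tensor, torch.Tensor) and tensor.ndim == 2:
        return tensor.unsqueeze(-1)
    return tensor

def _restore_tensor_shape(tensor):
    if isinstance(tensor, list):
        return [_restore_tensor_shape(t) for t in tensor]
    if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1:
        return tensor.squeeze(-1)
    return tensor

# A 2D activation (e.g. [seq, batch]) is padded to 3D before send/recv
# and squeezed back afterwards; already-3D tensors pass through unchanged.
act_2d = torch.randn(128, 4)
sent = _ensure_3d_tensor(act_2d)        # shape [128, 4, 1]
restored = _restore_tensor_shape(sent)  # shape [128, 4]
assert restored.shape == act_2d.shape

# VPP case: lists of tensors are handled element-wise.
chunks = [torch.randn(128, 4), torch.randn(128, 4, 1024)]
wrapped = _ensure_3d_tensor(chunks)     # shapes [128, 4, 1] and [128, 4, 1024]
```

One consequence of this pairing: a tensor that is genuinely 3D with a trailing dimension of 1 would also be squeezed on the way back, since the helpers cannot distinguish it from a padded 2D tensor.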
16 changes: 16 additions & 0 deletions megatron/core/pipeline_parallel/p2p_communication.py
@@ -7,6 +7,7 @@
import torch.distributed as dist

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
from megatron.core.utils import nvtx_decorator

# Types
@@ -162,6 +163,21 @@ def __init__(self, pp_group: dist.ProcessGroup, config: ModelParallelConfig):
else None
)

@property
def is_pp_first_stage(self):
"""Return True if pp first stage."""
return is_pp_first_stage(self.pp_group)

@property
def is_pp_last_stage(self):
"""Return True if pp last stage."""
return is_pp_last_stage(self.pp_group)

@property
def num_warmup_microbatches(self):
"""Return number of warmup microbatches."""
return self.pp_group.size() - self.pp_group.rank() - 1

def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, recv_next):
"""Communicate tensor shapes between stages. Used to communicate
tensor shapes before the actual tensor communication happens.
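The new `num_warmup_microbatches` property follows the usual 1F1B convention: each pipeline stage runs `pp_size - pp_rank - 1` forward-only microbatches before entering steady state. A standalone illustration of that arithmetic (a hypothetical helper, not the `P2PCommunicator` API itself); note that Megatron's non-interleaved schedule additionally clamps this value to the total number of microbatches:

```python
def warmup_microbatches(pp_size: int, pp_rank: int, num_microbatches: int) -> int:
    """Hypothetical helper mirroring the property above, with the usual clamp."""
    return min(pp_size - pp_rank - 1, num_microbatches)

# A 4-stage pipeline with 8 microbatches:
for rank in range(4):
    print(f"pp_rank={rank}: {warmup_microbatches(4, rank, 8)} warmup microbatches")
# pp_rank=0: 3, pp_rank=1: 2, pp_rank=2: 1, pp_rank=3: 0
```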