Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2163,8 +2163,9 @@ def set_save_original_input(module):

try:
# pylint: disable=unused-import
from transformer_engine.pytorch import cpu_offload
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch import cpu_offload_v1 as cpu_offload
except ImportError:
Float8Tensor = None
cpu_offload = None
try:
from transformer_engine.pytorch import cpu_offload
except ImportError:
cpu_offload = None
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import torch

from megatron.core.pipeline_parallel.utils import set_ideal_affinity_for_current_gpu

# CPU offload implementation for pipeline parallelism
DEBUG = False
DEBUG_RANK = 0
Expand All @@ -22,39 +24,6 @@ def debug_rank(message):
print(message)


def set_ideal_affinity_for_current_gpu():
    """Set CPU affinity for the current GPU to optimize host-device transfers.

    Best-effort: warns and returns early when the optional ``cuda-python``
    or ``pynvml`` packages are not importable. Otherwise queries the current
    CUDA device's UUID and asks NVML to bind this process to the CPUs local
    to that GPU's NUMA node, which speeds up pinned host<->device copies.
    """
    import uuid

    try:
        # Preferred: newer cuda-python module layout (cuda.bindings.*).
        import cuda.bindings.driver as cuda_driver
        import cuda.bindings.runtime as cuda_runtime
    except ImportError:
        try:
            # Fallback: legacy cuda-python layout (cuda.cuda / cuda.cudart).
            import cuda.cuda as cuda_driver
            import cuda.cudart as cuda_runtime
        except ImportError:
            # print("cuda-python may not be installed, skipping GPU affinity setting")
            # NOTE(review): relies on a module-level `warnings` import not
            # visible in this chunk — verify it exists at file top.
            warnings.warn("cuda-python may not be installed, skipping GPU affinity setting")
            return
    try:
        import pynvml
    except ImportError:
        warnings.warn("pynvml is not installed, skipping GPU affinity setting")
        return

    # Get current CUDA device ID
    # NOTE(review): `assert` is stripped under `python -O`, so these checks
    # vanish in optimized runs — consider raising RuntimeError instead.
    err, device_id = cuda_runtime.cudaGetDevice()
    assert err == cuda_runtime.cudaError_t.cudaSuccess
    # Get device UUID
    err, device_uuid = cuda_driver.cuDeviceGetUuid(device_id)
    assert err == cuda_driver.CUresult.CUDA_SUCCESS
    # Set CPU affinity based on GPU's NUMA node
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes)))
    pynvml.nvmlDeviceSetCpuAffinity(handle)


class PipelineOffloadManager:
"""
Singleton manager for coordinating activation offloading across pipeline stages.
Expand Down Expand Up @@ -200,6 +169,8 @@ def __enter__(self):

if cpu_offload is not None:
cpu_offload.CPUOffloadEnabled = True
else:
raise RuntimeError("TE CPU offload is not available")
self.inside_context = True

torch._C._autograd._push_saved_tensors_default_hooks(
Expand All @@ -213,6 +184,8 @@ def __exit__(self, *args: Any):

if cpu_offload is not None:
cpu_offload.CPUOffloadEnabled = False
else:
raise RuntimeError("TE CPU offload is not available")
self.inside_context = False
torch._C._autograd._pop_saved_tensors_default_hooks()

Expand Down Expand Up @@ -244,24 +217,18 @@ class ChunkOffloadHandler:
def offload(src_tensor, pin_memory=True):
"""Offload."""
debug_rank("--------offload")
from megatron.core.extensions.transformer_engine import Float8Tensor

fp8_offload = isinstance(src_tensor, Float8Tensor) if Float8Tensor is not None else False

if not src_tensor.is_contiguous():
src_tensor = src_tensor.contiguous()

cpu_backup = torch.empty(
src_tensor.size(),
dtype=torch.uint8 if fp8_offload else src_tensor.dtype,
dtype=src_tensor.dtype,
layout=src_tensor.layout,
device="cpu",
pin_memory=pin_memory,
)

if fp8_offload:
cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup)

cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
Expand Down
33 changes: 33 additions & 0 deletions megatron/core/pipeline_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,39 @@ def make_viewless(e):
return e


def set_ideal_affinity_for_current_gpu():
    """Set CPU affinity for the current GPU to optimize host-device transfers.

    Best-effort: warns and returns early when the optional ``cuda-python``
    or ``pynvml`` packages are not importable. Otherwise queries the current
    CUDA device's UUID and asks NVML to bind this process to the CPUs local
    to that GPU's NUMA node, which speeds up pinned host<->device copies.

    Raises:
        RuntimeError: if a CUDA runtime/driver query fails after the optional
            dependencies were successfully imported.
    """
    import uuid
    # Local import keeps this helper self-contained regardless of the
    # module's top-of-file imports.
    import warnings

    try:
        # Preferred: newer cuda-python module layout (cuda.bindings.*).
        import cuda.bindings.driver as cuda_driver
        import cuda.bindings.runtime as cuda_runtime
    except ImportError:
        try:
            # Fallback: legacy cuda-python layout (cuda.cuda / cuda.cudart).
            import cuda.cuda as cuda_driver
            import cuda.cudart as cuda_runtime
        except ImportError:
            warnings.warn("cuda-python may not be installed, skipping GPU affinity setting")
            return
    try:
        import pynvml
    except ImportError:
        warnings.warn("pynvml is not installed, skipping GPU affinity setting")
        return

    # Get current CUDA device ID. Explicit error checks (rather than
    # `assert`) so validation survives `python -O`.
    err, device_id = cuda_runtime.cudaGetDevice()
    if err != cuda_runtime.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaGetDevice failed: {err}")
    # Get device UUID for the NVML lookup below.
    err, device_uuid = cuda_driver.cuDeviceGetUuid(device_id)
    if err != cuda_driver.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuDeviceGetUuid failed: {err}")
    # Set CPU affinity based on the GPU's NUMA node via NVML.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes)))
    pynvml.nvmlDeviceSetCpuAffinity(handle)


@contextmanager
def stream_acquire_context(stream, event):
"""Stream acquire context"""
Expand Down