Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,304 changes: 1,304 additions & 0 deletions graph.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion run_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ set -ex
# LOG_RANK=0,1 NGPU=4 ./run_train.sh
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"}

overrides=""
if [ $# -ne 0 ]; then
Expand Down
7 changes: 7 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,12 @@ class Experimental:
"""


@dataclass
class Offloading:
    """Activation offloading settings.

    Both fields default to values that leave offloading disabled.
    """

    # Which offloading strategy to use; "none" (the default) disables it.
    mode: Literal["multistream", "sequential", "none"] = "none"
    # Ratio handed to the selected offloading implementation. Defaults to 0.0.
    # NOTE(review): presumably the fraction of activations to offload, in
    # [0, 1] -- confirm against the consumers of this config.
    offload_ratio: float = 0.0


@dataclass
class JobConfig:
"""
Expand All @@ -568,6 +574,7 @@ class JobConfig:
activation_checkpoint: ActivationCheckpoint = field(
default_factory=ActivationCheckpoint
)
offloading: Offloading = field(default_factory=Offloading)
float8: Float8 = field(default_factory=Float8)
mx: MX = field(default_factory=MX)
comm: Comm = field(default_factory=Comm)
Expand Down
3 changes: 2 additions & 1 deletion torchtitan/models/llama3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


import contextlib
from dataclasses import dataclass

import torch
Expand Down Expand Up @@ -506,6 +507,6 @@ def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer":

Returns:
Transformer: Transformer model.

"""

return cls(model_args)
29 changes: 29 additions & 0 deletions torchtitan/models/llama3/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

import contextlib
import functools
from collections import defaultdict

import torch
Expand All @@ -30,6 +32,7 @@
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger
from torchtitan.offloading import activation_offload_with_overlap


def parallelize_llama(
Expand Down Expand Up @@ -75,6 +78,9 @@ def parallelize_llama(
if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

if job_config.offloading.mode == "multistream":
apply_ao(model, job_config.offloading)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if job_config.training.compile:
apply_compile(model)
Expand Down Expand Up @@ -298,6 +304,29 @@ def apply_ac(model: nn.Module, ac_config):
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def _apply_act_offloading_to_transformer_block(module: nn.Module, ao_config):
offload_context = contextlib.nullcontext
if ao_config.mode == "multistream":
offload_context = functools.partial(activation_offload_with_overlap, module, ao_config.offload_ratio)

original_forward = module.forward
def new_forward(*args, **kwargs):
with offload_context():
return original_forward(*args, **kwargs)

module.forward = new_forward
return module


def apply_ao(model: nn.Module, ao_config):
    """Apply multistream activation offloading to the model.

    Every child of ``model.layers`` gets its forward wrapped and is then
    re-registered under its existing name.
    """
    for name, block in model.layers.named_children():
        model.layers.register_module(
            name,
            _apply_act_offloading_to_transformer_block(block, ao_config),
        )

    logger.info(f"Applied {ao_config.mode} activation offloading to the model")


def apply_compile(model: nn.Module):
"""
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
Expand Down
5 changes: 5 additions & 0 deletions torchtitan/models/llama3/train_configs/llama3_8b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ description = "Llama 3 8B training"

[profiling]
enable_profiling = true
enable_memory_snapshot = true
save_traces_folder = "profile_trace"
profile_freq = 100

Expand Down Expand Up @@ -56,6 +57,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

[offloading]
mode = "multistream" # ["multistream", "sequential", "none"]
offload_ratio = 1.0

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
Expand Down
187 changes: 187 additions & 0 deletions torchtitan/offloading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import torch
import logging
import random
from torch.nn import Module
from torch.autograd.graph import saved_tensors_hooks
from typing import NamedTuple
from collections import defaultdict


logger = logging.getLogger(__name__)


class PackInfo(NamedTuple):
    """Forward-pass bookkeeping for one tensor being copied device-to-host."""

    # Record an event in the offload stream for the default stream to wait on
    # before freeing the device tensor
    d2h_event: torch.cuda.Event
    # Keep a ref to the device tensor until the event has been waited on
    device_tensor: torch.Tensor


class UnpackInfo(NamedTuple):
    """Backward-pass bookkeeping for one preallocated host-to-device copy."""

    # Record an event during preallocation for the offload stream to wait on
    # before copying to the device tensor
    prealloc_event: torch.cuda.Event
    # Preallocate the device tensor memory so it can be allocated in the
    # default stream (instead of offload stream) to avoid fragmentation
    device_tensor: torch.Tensor


# TODO: Remove these from global namespace and register on modules. Consider
# using module state as identifier instead of int ID.
# NOTE: everything below is module-level state shared by every
# `activation_offload_with_overlap` instance.
# Used for overlapping H2D/D2H copy with compute. The stream is created when
# this module is imported.
offload_stream: torch.cuda.Stream = torch.cuda.Stream()
# Used for module ordering: maps the int ID assigned in
# `activation_offload_with_overlap.__init__` to the module it was created for
module_id_to_module: dict[int, Module] = {}
# Next int ID to hand out; incremented on every context construction
next_module_id = 0
# Used in forward to keep device tensors alive through D2H copies
module_to_pack_infos: dict[Module, list[PackInfo]] = defaultdict(list)
# Appended to in forward and used in backward to know which CPU tensors will be
# copied H2D in backward to preallocate their device memory
module_to_cpu_tensors: dict[Module, list[torch.Tensor]] = defaultdict(list)
# Used in backward to preallocate device tensors in the default stream
cpu_tensor_to_unpack_info: dict[torch.Tensor, UnpackInfo] = {}


class activation_offload_with_overlap(saved_tensors_hooks):
    """
    Saved-tensor hooks that offload a module's activations to CPU on a
    dedicated stream.

    In forward, we overlap the current module's D2H copies with the next
    module's forward compute.

    In backward, we overlap the current module's H2D copies with the previous
    module's backward compute.

    In backward, since we need to allocate new device memory for the H2D
    destinations, we can either do so in the offload stream or in the default
    stream. Naively, we may do so in the offload stream, but this fragments the
    memory pool since memory blocks are not shared across streams. As such, we
    instead choose to do so in the default stream. This requires preallocation
    and a CUDA event to ensure that the H2D copy does not start too early,
    using the default stream memory before it should.

    Args:
        module: Module whose saved activations are managed by this context.
        offload_ratio: Target fraction of eligible tensors to offload. Values
            outside ``[0.0, 1.0]`` are clamped into that range.
    """

    def __init__(self, module: Module, offload_ratio: float = 1.0) -> None:
        global next_module_id

        # Every constructed context takes the next int ID, so "module i - 1"
        # below means the module of the context constructed just before this
        # one.
        module_id = next_module_id
        module_id_to_module[module_id] = module
        next_module_id += 1
        # Tensors of these dtypes are never offloaded.
        self.ignore_types = [torch.complex64, torch.int64]
        # Tensors smaller than this many bytes are never offloaded.
        self.min_tensor_size_bytes = 1 * 1024 * 1024
        self.offload_ratio = max(0.0, min(1.0, offload_ratio))
        # Running counts of eligible tensors, used to track `offload_ratio`.
        self.tensors_offloaded = 0
        self.tensors_kept_on_gpu = 0

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            """Return the size of ``x`` in bytes (element size * element count)."""
            return x.element_size() * x.nelement()

        def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]:
            """Pack hook: return ``(original device, tensor to save)``.

            The saved tensor is a CPU copy when the tensor is offloaded and
            the original tensor otherwise.
            """
            if tensor.device.type == "cpu":
                return (tensor.device, tensor)

            num_bytes = get_num_bytes_tensor(tensor)

            device_tensor = tensor  # rename for clarity
            del tensor

            # TODO: Insert optional policy for deciding whether to offload.
            # Migrate to be like non-reentrant activation checkpointing in the
            # future to reuse the selective activation checkpointing logic.
            # Compare the size in bytes (not the element count) against the
            # byte threshold.
            if (
                num_bytes < self.min_tensor_size_bytes
                or device_tensor.dtype in self.ignore_types
            ):
                return (device_tensor.device, device_tensor)

            # Offload only while the fraction offloaded so far (counting this
            # tensor as not yet offloaded) is below the target ratio.
            should_offload = (
                self.tensors_offloaded
                / (self.tensors_offloaded + self.tensors_kept_on_gpu + 1)
                < self.offload_ratio
            )
            if not should_offload:
                self.tensors_kept_on_gpu += 1
                return (device_tensor.device, device_tensor)

            current_stream = torch.cuda.current_stream()

            module_id_to_free = module_id - 1
            if module_id_to_free in module_id_to_module:
                # Have the first of module i to free all of module i-1
                module_to_free = module_id_to_module[module_id_to_free]
                self.free_packed_device_tensors(module_to_free)

            # The copy must not start before the work already queued on the
            # current stream has run.
            offload_stream.wait_stream(current_stream)
            with torch.cuda.stream(offload_stream):
                cpu_tensor = device_tensor.to(torch.device("cpu"), non_blocking=True)
            d2h_event = offload_stream.record_event()
            self.tensors_offloaded += 1

            module_to_cpu_tensors[module].append(cpu_tensor)
            # Holding `device_tensor` here keeps its memory alive until the
            # D2H event has been waited on.
            module_to_pack_infos[module].append(PackInfo(d2h_event, device_tensor))
            return (device_tensor.device, cpu_tensor)

        def unpack_from_cpu(packed) -> torch.Tensor:
            """Unpack hook: return the saved tensor on its original device."""
            device, tensor = packed
            if tensor.device == device:
                # Not offloaded in forward; nothing to copy back.
                return tensor
            assert tensor.device == torch.device("cpu"), f"{tensor.device}"

            cpu_tensor = tensor  # rename for clarity
            del tensor

            # Clear any existing refs from forward (this should only happen for
            # the last module)
            self.free_packed_device_tensors(module)

            current_stream = torch.cuda.current_stream()
            module_id_to_prealloc = module_id - 1

            if module_id_to_prealloc in module_id_to_module:
                module_to_prealloc = module_id_to_module[module_id_to_prealloc]
                if module_to_prealloc in module_to_cpu_tensors:
                    # Preallocate, in the current stream, the device
                    # destinations for the module with the preceding ID.
                    cpu_tensors = module_to_cpu_tensors[module_to_prealloc]
                    for _cpu_tensor in cpu_tensors:
                        cpu_tensor_to_unpack_info[_cpu_tensor] = UnpackInfo(
                            current_stream.record_event(),
                            torch.empty_like(_cpu_tensor, device=device),
                        )
                    del module_to_cpu_tensors[module_to_prealloc]

            if cpu_tensor in cpu_tensor_to_unpack_info:  # prefetched
                event, device_tensor = cpu_tensor_to_unpack_info[cpu_tensor]
                # Do not start copying before the preallocation point.
                offload_stream.wait_event(event)
                del cpu_tensor_to_unpack_info[cpu_tensor]
            else:
                device_tensor = torch.empty_like(cpu_tensor, device=device)
                # Preallocate the rest of the 1st backward module
                for _cpu_tensor in module_to_cpu_tensors[module]:
                    if _cpu_tensor is cpu_tensor:
                        continue
                    cpu_tensor_to_unpack_info[_cpu_tensor] = UnpackInfo(
                        current_stream.record_event(),
                        torch.empty_like(_cpu_tensor, device=device),
                    )
                del module_to_cpu_tensors[module]
                # The destination was just allocated on the current stream, so
                # order the copy after the work queued there.
                offload_stream.wait_stream(current_stream)

            with torch.cuda.stream(offload_stream):
                device_tensor.copy_(cpu_tensor, non_blocking=True)
            # The consumer runs on the current stream; make it wait for the
            # copy before using the tensor.
            current_stream.wait_stream(offload_stream)

            return device_tensor

        super().__init__(pack_to_cpu, unpack_from_cpu)

    def free_packed_device_tensors(self, module: torch.nn.Module):
        """Drop the device-tensor refs recorded for ``module`` during forward.

        Does nothing if no pack infos are recorded for ``module``.
        """
        if module in module_to_pack_infos:
            if infos := module_to_pack_infos[module]:
                # Make sure that the default stream does not reuse any of
                # the previous activation memory until the D2H finish
                torch.cuda.current_stream().wait_event(infos[-1].d2h_event)
            del module_to_pack_infos[module]

23 changes: 15 additions & 8 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import importlib
import os
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.elastic.multiprocessing.errors import record
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.metrics import (
build_metrics_processor,
Expand Down Expand Up @@ -359,13 +360,19 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
)
else:
# Non-PP forward / backward
with self.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()
offload_context = contextlib.nullcontext()
if self.job_config.offloading.mode == "sequential":
offload_ratio = self.job_config.offloading.offload_ratio
offload_context = torch.autograd.graph.manage_activations(offload_ratio=offload_ratio)

with offload_context:
with self.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()

dist_utils.clip_grad_norm_(
[p for m in model_parts for p in m.parameters()],
Expand Down