Add support for saving HF format tensors with DCP #1351

Draft: wants to merge 8 commits into main
15 changes: 15 additions & 0 deletions tests/integration_tests.py
@@ -118,6 +118,21 @@ def build_test_list():
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_hf_safetensors_format",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable_hf_safetensors_format",
"--training.steps 20",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint_hf_safetensors",
),
OverrideDefinitions(
[
[
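The new entry exercises a save/resume round trip: the first override list runs training with HF safetensors checkpointing enabled, and the second re-launches with `--training.steps 20`, which forces a load of the checkpoint the first run produced. A rough sketch of what the harness does with these lists (the `./run_train.sh` entry point and the exact plumbing are assumptions, not shown in this diff):

```python
import subprocess

# Each inner list is one training invocation; running them in order
# gives save-then-resume. Entry point and argument splitting are illustrative.
overrides = [
    ["--checkpoint.enable_checkpoint", "--checkpoint.enable_hf_safetensors_format"],
    ["--checkpoint.enable_checkpoint", "--checkpoint.enable_hf_safetensors_format",
     "--training.steps", "20"],
]
for args in overrides:
    subprocess.run(["./run_train.sh", *args], check=True)
```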
4 changes: 2 additions & 2 deletions tests/unit_tests/test_checkpoint.py
@@ -144,7 +144,7 @@ def tearDown(self):
         shutil.rmtree(self.base_temp_dir)
         time.sleep(0.1)

-    def fake_save(self, state_dict: dict, checkpoint_id: str):
+    def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
         os.makedirs(checkpoint_id, exist_ok=True)
         sd_to_save = {}
         for key, val in state_dict.items():
@@ -584,7 +584,7 @@ def __init__(self):
@mock.patch("torchtitan.components.checkpoint.dcp.load")
@mock.patch("torchtitan.components.checkpoint.dcp.save")
def test_verify_prefix(self, mock_save, mock_load, mock_rank):
def fake_save(state_dict: dict, checkpoint_id: str):
def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
self.assertIn("bias", state_dict)
self.assertIn("weight", state_dict)
# No model prefix
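Both fakes gain a `storage_writer=None` parameter because the checkpointer now forwards `storage_writer` on every `dcp.save` call, and a `side_effect` function must accept whatever arguments the patched function receives. A minimal, standalone illustration of that constraint (not the actual test wiring; the real tests patch `torchtitan.components.checkpoint.dcp.save`):

```python
from unittest import mock

import torch.distributed.checkpoint as dcp

def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
    # Must mirror the real call signature: the code under test now always
    # forwards storage_writer, so omitting the parameter raises TypeError.
    pass

with mock.patch("torch.distributed.checkpoint.save", side_effect=fake_save):
    dcp.save({}, checkpoint_id="/tmp/ckpt", storage_writer=None)
```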
117 changes: 99 additions & 18 deletions torchtitan/components/checkpoint.py
@@ -12,12 +12,17 @@
 import shutil
 import threading
 import time
+from concurrent.futures import Future
+from typing import Any

 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed.checkpoint import (
+    HuggingFaceStorageReader,
+    HuggingFaceStorageWriter,
+)
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -92,12 +97,6 @@ class SaveDone:
     pass


-@torch.no_grad()
-def save_with_gc(state, checkpoint_id):
-    dcp.save(state, checkpoint_id=checkpoint_id)
-    GarbageCollection.collect("GC collection invoked by checkpointer.")
-
-
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.

@@ -190,6 +189,7 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format
         self.ft_manager = ft_manager.manager if ft_manager.enabled else None

         if self.ft_manager:
@@ -312,6 +312,72 @@ def close(self):
         if self.stager is not None:
             self.stager.close()

+    @torch.no_grad()
+    def dcp_save(
+        self,
+        state_dict: dict[str, Any],
+        checkpoint_id: str,
+        async_mode: AsyncMode,
+        enable_garbage_collection: bool = False,
+        is_last_step: bool = False,
+    ) -> Future | None:
+        """Save the checkpoint with dcp.
+
+        Args:
+            state_dict (dict): The state dict to save.
+            checkpoint_id (str): The checkpoint id to save.
+            async_mode (AsyncMode): Whether (and how) to save asynchronously.
+            enable_garbage_collection (bool): Whether to run garbage collection
+                after the save.
+            is_last_step (bool): Whether this is the last-step save, which is
+                when the sharded safetensors files are consolidated into full
+                tensors.
+
+        Returns:
+            Future: The future object if the checkpoint is async, otherwise None.
+        """
+        ret: Future | None = None
+
+        storage_writer = (
+            HuggingFaceStorageWriter(
+                path=checkpoint_id,
+                save_distributed=True,
+                enable_consolidation=is_last_step,
+            )
+            if self.enable_hf_safetensors_format
+            else None
+        )
+        # Pass checkpoint_id only with the default DCP layout; the HF writer
+        # already carries the path.
+        checkpoint_save_id = (
+            checkpoint_id if not self.enable_hf_safetensors_format else None
+        )
+        if async_mode == AsyncMode.ASYNC:
+            ret = dcp.async_save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=checkpoint_save_id,
+                process_group=self.pg,
+            )
+        elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+            ret = dcp.async_save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=checkpoint_save_id,
+                process_group=self.pg,
+                async_checkpointer_type=AsyncCheckpointerType.PROCESS,
+                async_stager=self.stager,
+            )
+        else:
+            ret = dcp.save(
+                state_dict,
+                storage_writer=storage_writer,
+                checkpoint_id=checkpoint_save_id,
+            )
+
+        if enable_garbage_collection:
+            GarbageCollection.collect("GC collection invoked by checkpointer.")
+
+        return ret
+
+    def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str) -> None:
+        """Load the checkpoint with dcp.
+
+        Args:
+            state_dict (dict): The state dict to load.
+            checkpoint_id (str): The checkpoint id to load.
+        """
+        if self.enable_hf_safetensors_format:
+            storage_reader = HuggingFaceStorageReader(path=checkpoint_id)
+            dcp.load(state_dict, storage_reader=storage_reader)
+        else:
+            dcp.load(state_dict, checkpoint_id=checkpoint_id)
+
     @torch.no_grad()
     def save(self, curr_step: int, last_step: bool = False) -> None:
         """Save the checkpoint for the current step.
@@ -352,23 +418,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
GarbageCollection.collect("GC collection invoked by checkpointer.")
if self.stager is None:
self.stager = DefaultStager(StagingOptions(True, True, True, True))
result = dcp.async_save(
result = self.dcp_save(
states,
checkpoint_id=checkpoint_id,
process_group=self.pg,
async_checkpointer_type=AsyncCheckpointerType.PROCESS,
async_stager=self.stager,
async_mode=self.async_mode,
)
self.save_future = result.upload_completion
self.staging_future = result.staging_completion
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.save_future = dcp.async_save(
states, checkpoint_id=checkpoint_id, process_group=self.pg
self.save_future = self.dcp_save(
states, checkpoint_id=checkpoint_id, async_mode=self.async_mode
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(states, checkpoint_id=checkpoint_id)
self.dcp_save(
states,
checkpoint_id=checkpoint_id,
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
)
self._purge_stale_checkpoints()

logger.info(
@@ -433,7 +502,10 @@ def load(self, step: int = -1) -> bool:
logger.info(f"Loading the checkpoint from {checkpoint_id}.")
begin = time.monotonic()
states = self._states_to_load(model_only)
dcp.load(states, checkpoint_id=checkpoint_id)
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -486,8 +558,8 @@ def _ft_save(self, step: int) -> None:
         begin = time.monotonic()
         self._async_wait()
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        self.save_future = dcp.async_save(
-            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        self.save_future = self.dcp_save(
+            self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
         )
         logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")

@@ -499,7 +571,10 @@ def _ft_load(self) -> None:
         begin = time.monotonic()
         logger.info(f"Loading the FT checkpoint at step {step}.")
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        self.dcp_load(
+            self.ft_states,
+            checkpoint_id=checkpoint_id,
+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -568,7 +643,13 @@ def _save_last_step(self, curr_step: int) -> None:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
states = self._flattened_model_states_sd()

save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
self.dcp_save(
states,
checkpoint_id=self._create_checkpoint_id(curr_step),
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
is_last_step=True,
)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
if not self.enable_checkpoint:
8 changes: 8 additions & 0 deletions torchtitan/config_manager.py
@@ -467,6 +467,14 @@ class Checkpoint:
     for many steps or checkpointing too frequently. The default value is False.
     """

+    enable_hf_safetensors_format: bool = False
+    """
+    Enable saving checkpoints in the HuggingFace safetensors format instead of the
+    default DCP format. This has a performance cost, because the sharded tensors
+    must be consolidated into full tensors in a separate step. The default value
+    is False.
+    """
+

 @dataclass
 class ActivationCheckpoint:
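Because the consolidated last-step output is plain safetensors, it should be readable with the standard `safetensors` library. A hedged sketch; the exact file name and layout under the checkpoint folder are assumptions, so adjust to whatever the consolidated writer actually produces:

```python
from safetensors import safe_open

# Illustrative path: depends on your configured checkpoint folder and step count.
with safe_open("outputs/checkpoint/step-20/model.safetensors",
               framework="pt", device="cpu") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))
```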