diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 3ccbc1890..dca2610be 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -118,6 +118,21 @@ def build_test_list():
             "Checkpoint Integration Test - Save Load Full Checkpoint",
             "full_checkpoint",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.enable_hf_safetensors_format",
+                ],
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--checkpoint.enable_hf_safetensors_format",
+                    "--training.steps 20",
+                ],
+            ],
+            "Checkpoint Integration Test - Save Load Full Checkpoint",
+            "full_checkpoint_hf_safetensors",
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index 3317a51fe..2f8127bfd 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -144,7 +144,7 @@ def tearDown(self):
         shutil.rmtree(self.base_temp_dir)
         time.sleep(0.1)
 
-    def fake_save(self, state_dict: dict, checkpoint_id: str):
+    def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None):
         os.makedirs(checkpoint_id, exist_ok=True)
         sd_to_save = {}
         for key, val in state_dict.items():
@@ -584,7 +584,7 @@ def __init__(self):
     @mock.patch("torchtitan.components.checkpoint.dcp.load")
    @mock.patch("torchtitan.components.checkpoint.dcp.save")
    def test_verify_prefix(self, mock_save, mock_load, mock_rank):
-        def fake_save(state_dict: dict, checkpoint_id: str):
+        def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None):
             self.assertIn("bias", state_dict)
             self.assertIn("weight", state_dict)
             # No model prefix
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index ff055cbe7..4b56d16d4 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -12,12 +12,17 @@
 import shutil
 import threading
 import time
+from concurrent.futures import Future
 from typing import Any
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed.checkpoint import (
+    HuggingFaceStorageReader,
+    HuggingFaceStorageWriter,
+)
 from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
@@ -92,12 +97,6 @@ class SaveDone:
     pass
 
 
-@torch.no_grad()
-def save_with_gc(state, checkpoint_id):
-    dcp.save(state, checkpoint_id=checkpoint_id)
-    GarbageCollection.collect("GC collection invoked by checkpointer.")
-
-
 def purge_thread(purge_queue: queue.Queue):
     """Thread to purge the old checkpoints.
 
@@ -190,6 +189,7 @@ def __init__(
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.enable_hf_safetensors_format = ckpt_config.enable_hf_safetensors_format
         self.ft_manager = ft_manager.manager if ft_manager.enabled else None
 
         if self.ft_manager:
@@ -312,6 +312,72 @@ def close(self):
         if self.stager is not None:
             self.stager.close()
 
+    @torch.no_grad()
+    def dcp_save(
+        self,
+        state_dict: dict[str, Any],
+        checkpoint_id: str,
+        async_mode: AsyncMode,
+        enable_garbage_collection: bool = False,
+        is_last_step: bool = False,
+    ) -> Future | None:
+        """Save the checkpoint with dcp.
+
+        Args:
+            state_dict (dict): The state dict to save.
+            checkpoint_id (str): The checkpoint id to save.
+            async_mode (AsyncMode): Which async mode to use for this save.
+            enable_garbage_collection (bool): Whether to run garbage collection
+                after the save.
+            is_last_step (bool): Whether this save is for the last training step,
+                in which case the safetensors output is consolidated.
+
+        Returns:
+            Future: The future object if the checkpoint is async, otherwise None.
+ """ + ret: Future | None = None + + storage_writer = ( + HuggingFaceStorageWriter( + path=checkpoint_id, save_distributed=True, enable_consolidation=is_last_step, + ) + if self.enable_hf_safetensors_format + else None + ) + id = checkpoint_id if not self.enable_hf_safetensors_format else None + if async_mode == AsyncMode.ASYNC: + ret = dcp.async_save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=id, + process_group=self.pg, + ) + elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: + ret = dcp.async_save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=id, + process_group=self.pg, + async_checkpointer_type=AsyncCheckpointerType.PROCESS, + async_stager=self.stager, + ) + else: + ret = dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=id) + + if enable_garbage_collection: + GarbageCollection.collect("GC collection invoked by checkpointer.") + + return ret + + def dcp_load(self, state_dict: dict[str, Any], checkpoint_id: str) -> None: + """Load the checkpoint with dcp. + Args: + state_dict (dict): The state dict to load. + checkpoint_id (str): The checkpoint id to load. + hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. + """ + + if self.enable_hf_safetensors_format: + storage_reader = HuggingFaceStorageReader(path=checkpoint_id) + dcp.load(state_dict, storage_reader=storage_reader) + else: + dcp.load(state_dict, checkpoint_id=checkpoint_id) + @torch.no_grad() def save(self, curr_step: int, last_step: bool = False) -> None: """Save the checkpoint for the current step. @@ -352,23 +418,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None: GarbageCollection.collect("GC collection invoked by checkpointer.") if self.stager is None: self.stager = DefaultStager(StagingOptions(True, True, True, True)) - result = dcp.async_save( + result = self.dcp_save( states, checkpoint_id=checkpoint_id, - process_group=self.pg, - async_checkpointer_type=AsyncCheckpointerType.PROCESS, - async_stager=self.stager, + async_mode=self.async_mode, ) self.save_future = result.upload_completion self.staging_future = result.staging_completion elif self.async_mode == AsyncMode.ASYNC: GarbageCollection.collect("GC collection invoked by checkpointer.") - self.save_future = dcp.async_save( - states, checkpoint_id=checkpoint_id, process_group=self.pg + self.save_future = self.dcp_save( + states, checkpoint_id=checkpoint_id, async_mode=self.async_mode ) GarbageCollection.collect("GC collection invoked by checkpointer.") else: - save_with_gc(states, checkpoint_id=checkpoint_id) + self.dcp_save( + states, + checkpoint_id=checkpoint_id, + async_mode=AsyncMode.DISABLED, + enable_garbage_collection=True, + ) self._purge_stale_checkpoints() logger.info( @@ -433,7 +502,10 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) - dcp.load(states, checkpoint_id=checkpoint_id) + self.dcp_load( + states, + checkpoint_id=checkpoint_id, + ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." 
@@ -486,8 +558,8 @@ def _ft_save(self, step: int) -> None:
         begin = time.monotonic()
         self._async_wait()
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        self.save_future = dcp.async_save(
-            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        self.save_future = self.dcp_save(
+            self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC
         )
         logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
 
@@ -499,7 +571,10 @@ def _ft_load(self) -> None:
         begin = time.monotonic()
         logger.info(f"Loading the FT checkpoint at step {step}.")
         checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
-        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        self.dcp_load(
+            self.ft_states,
+            checkpoint_id=checkpoint_id,
+        )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -568,7 +643,13 @@ def _save_last_step(self, curr_step: int) -> None:
         logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
         states = self._flattened_model_states_sd()
 
-        save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step))
+        self.dcp_save(
+            states,
+            checkpoint_id=self._create_checkpoint_id(curr_step),
+            async_mode=AsyncMode.DISABLED,
+            enable_garbage_collection=True,
+            is_last_step=True,
+        )
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
         if not self.enable_checkpoint:
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 3f8d25688..d567e987b 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -467,6 +467,14 @@ class Checkpoint:
     for many steps or checkpointing too frequently. The default value is False.
     """
 
+    enable_hf_safetensors_format: bool = False
+    """
+    Enable saving checkpoints in the HuggingFace safetensors format instead of the
+    default DCP format. There is a performance cost to this, since the sharded
+    tensors must be consolidated into full tensors as a separate step.
+    The default value is False.
+    """
+
 
 @dataclass
 class ActivationCheckpoint:
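
For reference, below is a minimal standalone sketch (not part of the patch) of the storage routing that the new Checkpointer.dcp_save/dcp_load methods implement: a save either goes through DCP's default format via checkpoint_id, or through HuggingFaceStorageWriter for safetensors output, and a load mirrors that choice with HuggingFaceStorageReader. The helper names and the single-process, synchronous usage are assumptions for illustration only; in the Checkpointer itself the expensive consolidation into full tensors is only enabled on the last step (is_last_step).

# Sketch only: mirrors the routing added in Checkpointer.dcp_save/dcp_load.
# The helper names are hypothetical; the storage writer/reader arguments are
# the ones used in the patch above.
from typing import Any

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import (
    HuggingFaceStorageReader,
    HuggingFaceStorageWriter,
)


def save_state_dict(state_dict: dict[str, Any], checkpoint_id: str, use_hf: bool) -> None:
    if use_hf:
        # Write .safetensors shards; consolidating them into full tensors is a
        # separate, more expensive step (enable_consolidation=True).
        writer = HuggingFaceStorageWriter(
            path=checkpoint_id, save_distributed=True, enable_consolidation=True
        )
        dcp.save(state_dict, storage_writer=writer)
    else:
        dcp.save(state_dict, checkpoint_id=checkpoint_id)


def load_state_dict(state_dict: dict[str, Any], checkpoint_id: str, use_hf: bool) -> None:
    if use_hf:
        dcp.load(state_dict, storage_reader=HuggingFaceStorageReader(path=checkpoint_id))
    else:
        dcp.load(state_dict, checkpoint_id=checkpoint_id)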