Add support for saving HF format tensors with DCP #1351


Status: Draft · wants to merge 1 commit into base: main
Conversation

@ankitageorge ankitageorge commented Jun 27, 2025

If checkpoint.enable_hf_safetensors_format is set, save the checkpoint with the DCP HuggingFace components, which write .safetensors files instead of the regular DCP format.

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jun 27, 2025
@ankitageorge changed the title from "Dcp hf" to "Add support for saving HF format tensors with DCP" on Jun 27, 2025
@fegin (Contributor) commented Jun 27, 2025

@Saiteja64 This will conflict with your PR.

@fegin (Contributor) left a comment

Overall the logic LGTM. Please address the comments and ensure that this PR doesn't conflict with the PR from @Saiteja64. Please also add a test result: save an HF checkpoint, load it back, and verify the accuracy.

@@ -12,14 +12,19 @@
 import shutil
 import threading
 import time
-from typing import Any
+from concurrent.futures import Future
+from typing import Any, Optional

We use 3.10 type checking, so you don't need Optional.

    checkpoint_id: str,
    is_async: bool,
    hf_safetensors_format: bool,
    pg: Optional[dist.ProcessGroup] = None,
Suggested change
-    pg: Optional[dist.ProcessGroup] = None,
+    pg: dist.ProcessGroup | None = None,

def dcp_save(
    state_dict: dict[str, Any],
    checkpoint_id: str,
    is_async: bool,
@Saiteja64 do we also need another argument for ZOC?

    is_async: bool,
    hf_safetensors_format: bool,
    pg: Optional[dist.ProcessGroup] = None,
) -> Optional[Future]:
Suggested change
-) -> Optional[Future]:
+) -> Future | None:

    hf_safetensors_format: bool,
    pg: Optional[dist.ProcessGroup] = None,
) -> Optional[Future]:
    """Save the checkpoint with dcp.
Add one empty line

        checkpoint_id (str): The checkpoint id to save.
        is_async (bool): Whether the checkpoint is async.
        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
        pg (Optional[dist.ProcessGroup]): The process group to use.
Add the return value as well.

Comment on lines +116 to +130
    if hf_safetensors_format:
        storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
        if is_async:
            return dcp.async_save(
                state_dict, storage_writer=storage_writer, process_group=pg
            )
        else:
            return dcp.save(state_dict, storage_writer=storage_writer)
    else:
        if is_async:
            return dcp.async_save(
                state_dict, checkpoint_id=checkpoint_id, process_group=pg
            )
        else:
            return dcp.save(state_dict, checkpoint_id=checkpoint_id)
We should simplify the function as follows

Suggested change
-    if hf_safetensors_format:
-        storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
-        if is_async:
-            return dcp.async_save(
-                state_dict, storage_writer=storage_writer, process_group=pg
-            )
-        else:
-            return dcp.save(state_dict, storage_writer=storage_writer)
-    else:
-        if is_async:
-            return dcp.async_save(
-                state_dict, checkpoint_id=checkpoint_id, process_group=pg
-            )
-        else:
-            return dcp.save(state_dict, checkpoint_id=checkpoint_id)
+    storage_writer = (
+        HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
+        if hf_safetensors_format
+        else None
+    )
+    checkpoint_id = checkpoint_id if not hf_safetensors_format else None
+    if is_async:
+        return dcp.async_save(
+            state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
+        )
+    else:
+        return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)
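The simplification above routes both formats through a single save call: exactly one of `storage_writer` / `checkpoint_id` is set. A minimal, runnable sketch of that routing logic, with `dcp` and `HuggingFaceStorageWriter` stubbed out so it runs without torch (the real code would use `torch.distributed.checkpoint`; the stubs just record which arguments the save call received):

```python
from typing import Any


class HuggingFaceStorageWriter:  # stand-in for the DCP HF writer
    def __init__(self, path: str, save_sharded: bool = False):
        self.path = path
        self.save_sharded = save_sharded


class _FakeDCP:  # records how each save was invoked, instead of writing files
    def save(self, state_dict, *, storage_writer=None, checkpoint_id=None, process_group=None):
        return ("sync", storage_writer, checkpoint_id)

    def async_save(self, state_dict, *, storage_writer=None, checkpoint_id=None, process_group=None):
        return ("async", storage_writer, checkpoint_id)


dcp = _FakeDCP()


def dcp_save(
    state_dict: dict[str, Any],
    checkpoint_id: str,
    is_async: bool,
    hf_safetensors_format: bool,
    pg=None,
):
    # Exactly one of storage_writer / checkpoint_id is non-None, so one
    # call site covers both the HF safetensors and the regular DCP format.
    storage_writer = (
        HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
        if hf_safetensors_format
        else None
    )
    checkpoint_id = checkpoint_id if not hf_safetensors_format else None
    if is_async:
        return dcp.async_save(
            state_dict, storage_writer=storage_writer,
            checkpoint_id=checkpoint_id, process_group=pg,
        )
    return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)


mode, writer, ckpt_id = dcp_save({}, "step-100", is_async=False, hf_safetensors_format=True)
print(mode, type(writer).__name__, ckpt_id)  # sync HuggingFaceStorageWriter None
```

With `hf_safetensors_format=True` the call carries a writer and no checkpoint id; with it off, the reverse.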

def dcp_load(
    state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
) -> None:
    """Load the checkpoint with dcp.
Add one empty line below

    enable_hf_safetensors_format: bool = False
    """
    Enable the use of safetensors format for checkpointing. This will save checkpoints
    in safetensors format instead of the default DCP format. The default value is False.
Can we also mention the possible performance penalty? It's not cost-free, right?
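For reference, the new option from this PR would be switched on in the job config. A hypothetical TOML fragment, assuming torchtitan's `[checkpoint]` section layout (every field here except `enable_hf_safetensors_format` is illustrative, not taken from this PR):

```toml
[checkpoint]
enable_checkpoint = true              # illustrative; other checkpoint fields omitted
enable_hf_safetensors_format = true   # new flag from this PR: write .safetensors via the DCP HF writer
```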
