Add support for saving HF format tensors with DCP #1351
base: main
@Saiteja64 This will conflict with your PR.
Overall, the logic LGTM. Please address the comments and ensure that this PR doesn't conflict with the PR from @Saiteja64. Please also add a test result: save an HF checkpoint, load it back, and check the accuracy.
@@ -12,14 +12,19 @@
 import shutil
 import threading
 import time
-from typing import Any
+from concurrent.futures import Future
+from typing import Any, Optional
We use 3.10 type checking, so you don't need Optional.
    checkpoint_id: str,
    is_async: bool,
    hf_safetensors_format: bool,
    pg: Optional[dist.ProcessGroup] = None,
-    pg: Optional[dist.ProcessGroup] = None,
+    pg: dist.ProcessGroup | None = None,
def dcp_save(
    state_dict: dict[str, Any],
    checkpoint_id: str,
    is_async: bool,
@Saiteja64 do we also need another argument for ZOC?
    is_async: bool,
    hf_safetensors_format: bool,
    pg: Optional[dist.ProcessGroup] = None,
) -> Optional[Future]:
-) -> Optional[Future]:
+) -> Future | None:
    hf_safetensors_format: bool,
    pg: Optional[dist.ProcessGroup] = None,
) -> Optional[Future]:
    """Save the checkpoint with dcp.
Add one empty line
        checkpoint_id (str): The checkpoint id to save.
        is_async (bool): Whether the checkpoint is async.
        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
        pg (Optional[dist.ProcessGroup]): The process group to use.
Add the return value as well.
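One way the docstring could look once the two docstring comments are addressed (blank line after the summary, plus a return entry) is sketched below. The function name `dcp_save_sketch`, the `state_dict` description, and the wording of the `Returns` entry are assumptions, and `Any` stands in for `dist.ProcessGroup` so the snippet runs without torch.

```python
from __future__ import annotations

from concurrent.futures import Future
from typing import Any


def dcp_save_sketch(
    state_dict: dict[str, Any],
    checkpoint_id: str,
    is_async: bool,
    hf_safetensors_format: bool,
    pg: Any | None = None,  # stand-in for dist.ProcessGroup (assumption)
) -> Future | None:
    """Save the checkpoint with dcp.

    Args:
        state_dict (dict): The state dict to save. (wording assumed)
        checkpoint_id (str): The checkpoint id to save.
        is_async (bool): Whether the checkpoint is async.
        hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format.
        pg (dist.ProcessGroup | None): The process group to use.

    Returns:
        Future | None: The future of the in-flight save when ``is_async``
        is True. (wording assumed)
    """
    # Body intentionally omitted: this sketch only illustrates docstring layout.
    raise NotImplementedError
```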
    if hf_safetensors_format:
        storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
        if is_async:
            return dcp.async_save(
                state_dict, storage_writer=storage_writer, process_group=pg
            )
        else:
            return dcp.save(state_dict, storage_writer=storage_writer)
    else:
        if is_async:
            return dcp.async_save(
                state_dict, checkpoint_id=checkpoint_id, process_group=pg
            )
        else:
            return dcp.save(state_dict, checkpoint_id=checkpoint_id)
We should simplify the function as follows:
-    if hf_safetensors_format:
-        storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True)
-        if is_async:
-            return dcp.async_save(
-                state_dict, storage_writer=storage_writer, process_group=pg
-            )
-        else:
-            return dcp.save(state_dict, storage_writer=storage_writer)
-    else:
-        if is_async:
-            return dcp.async_save(
-                state_dict, checkpoint_id=checkpoint_id, process_group=pg
-            )
-        else:
-            return dcp.save(state_dict, checkpoint_id=checkpoint_id)
+    storage_writer = HuggingFaceStorageWriter(path=checkpoint_id, save_sharded=True) if hf_safetensors_format else None
+    checkpoint_id = checkpoint_id if not hf_safetensors_format else None
+    if is_async:
+        return dcp.async_save(
+            state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
+        )
+    else:
+        return dcp.save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)
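To check that the flattened version dispatches the same way as the nested one, here is a small torch-free sketch. `FakeWriter`, `fake_save`, and `fake_async_save` are hypothetical stand-ins for `HuggingFaceStorageWriter`, `dcp.save`, and `dcp.async_save`; the sketch assumes the real functions default `checkpoint_id` and `storage_writer` to `None`, which is what makes passing an explicit `None` equivalent to omitting the argument.

```python
from itertools import product
from typing import Any


class FakeWriter:
    """Hypothetical stand-in for HuggingFaceStorageWriter."""

    def __init__(self, path: str, save_sharded: bool) -> None:
        self.path = path
        self.save_sharded = save_sharded


def _record(name: str, checkpoint_id: Any, storage_writer: Any, process_group: Any) -> tuple:
    # Reduce a call to comparable values (the writer is compared by its path).
    return (name, checkpoint_id, getattr(storage_writer, "path", None), process_group)


def fake_save(state_dict, *, checkpoint_id=None, storage_writer=None, process_group=None):
    return _record("save", checkpoint_id, storage_writer, process_group)


def fake_async_save(state_dict, *, checkpoint_id=None, storage_writer=None, process_group=None):
    return _record("async_save", checkpoint_id, storage_writer, process_group)


def branched(state_dict, checkpoint_id, is_async, hf_safetensors_format, pg=None):
    """Nested version from the PR, with the stand-ins swapped in."""
    if hf_safetensors_format:
        storage_writer = FakeWriter(path=checkpoint_id, save_sharded=True)
        if is_async:
            return fake_async_save(state_dict, storage_writer=storage_writer, process_group=pg)
        return fake_save(state_dict, storage_writer=storage_writer)
    if is_async:
        return fake_async_save(state_dict, checkpoint_id=checkpoint_id, process_group=pg)
    return fake_save(state_dict, checkpoint_id=checkpoint_id)


def simplified(state_dict, checkpoint_id, is_async, hf_safetensors_format, pg=None):
    """Flattened version from the review suggestion, with the stand-ins swapped in."""
    storage_writer = (
        FakeWriter(path=checkpoint_id, save_sharded=True) if hf_safetensors_format else None
    )
    checkpoint_id = checkpoint_id if not hf_safetensors_format else None
    if is_async:
        return fake_async_save(
            state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id, process_group=pg
        )
    return fake_save(state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_id)


def variants_agree() -> bool:
    """Compare both versions over every (is_async, hf_safetensors_format) combination."""
    return all(
        branched({}, "step-10", a, h, "pg") == simplified({}, "step-10", a, h, "pg")
        for a, h in product([False, True], repeat=2)
    )
```

Under these assumptions the two versions issue identical calls in all four cases, so the simplification changes only the shape of the code. Note that in both versions the synchronous path does not forward `pg`.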
def dcp_load(
    state_dict: dict[str, Any], checkpoint_id: str, hf_safetensors_format: bool
) -> None:
    """Load the checkpoint with dcp.
Add one empty line below
    enable_hf_safetensors_format: bool = False
    """
    Enable the use of safetensors format for checkpointing. This will save checkpoints
    in safetensors format instead of the default DCP format. The default value is False.
Can we also mention the possible performance penalty? It's not cost free, right?
If checkpoint.enable_hf_safetensors_format is set, the checkpoint is saved with the DCP HF components, which write .safetensors files instead of the regular DCP format.