
setup stream dependencies inside work wrapper #248


Merged
merged 1 commit into pytorch:main from pr248 on Aug 1, 2025

Conversation

@tushar00jain (Contributor) commented Jul 26, 2025

Summary:

  • extend the work wrapper object to also do the division post allreduce
  • add api to block_current_stream on work wrapper so it can be used for HSDP

Stack created with Sapling. Best reviewed with ReviewStack.
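
As a rough illustration of what the summary describes, here is a minimal sketch of a Work wrapper that performs the post-allreduce division and exposes block_current_stream(). This is not torchft's actual _ManagedWork; the class, attribute, and parameter names are assumptions, and it assumes the underlying c10d Work exposes block_current_stream (as discussed later in this thread).

```python
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist


class _ExampleManagedWork(dist.Work):
    """Illustrative sketch only: wraps a collective's Work to (a) divide the
    tensor after the allreduce and (b) expose block_current_stream() for
    HSDP-style, non-blocking stream synchronization."""

    def __init__(self, work: dist.Work, tensor: torch.Tensor, world_size: int) -> None:
        super().__init__()
        self._work = work
        self._tensor = tensor
        self._world_size = world_size

    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
        # Make the caller's current stream depend on the collective without
        # blocking the host (relies on the underlying Work supporting this).
        self._work.block_current_stream()

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        # Blocking path (e.g. the DiLoCo sync step): wait for the allreduce,
        # then apply the post-allreduce division.
        self._work.wait()
        self._tensor.div_(self._world_size)
        return True

    def get_future(self) -> torch.futures.Future:
        # Set up the stream dependencies first, then hand out a future that
        # yields the averaged tensor. Real code must also make sure the
        # division is ordered after the collective on the right stream,
        # which is exactly what this PR is about.
        self.block_current_stream()
        fut = self._work.get_future()
        return fut.then(lambda f: self._tensor.div_(self._world_size))
```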

@facebook-github-bot added the CLA Signed label (This label is managed by the Meta Open Source bot.) on Jul 26, 2025
@tushar00jain changed the title from "option 2 - call worl.wait inside future callback" to "option 2 - call work.wait inside future callback" on Jul 26, 2025
@tushar00jain force-pushed the pr248 branch 2 times, most recently from 92ad240 to b9cd277 on July 26, 2025 04:43
@tushar00jain changed the title from "option 2 - call work.wait inside future callback" to "option 2 - call work.wait inside wrapped work" on Jul 26, 2025
@tushar00jain force-pushed the pr248 branch 3 times, most recently from 4162c4a to e137ed1 on July 26, 2025 18:23
        return True

    def get_future(self) -> torch.futures.Future[torch.Tensor]:
        self.wait()
Member:

.wait() should be a blocking call; we probably want to invert this logic and make .wait() call get_future() instead.

@tushar00jain (Contributor, Author) commented Jul 28, 2025:

yep, the .wait() and .get_future() should only be called when we want to block. for diloco, it's called at the sync step. iiuc the manager allreduce method doesn't get called from hsdp, so it's unrelated there. this also made me realize we call .get_future() for bucketized allreduce, though, where we don't want to block. thinking we can pass the callback to manager allreduce for that. guess it's not nice API-wise to block in this method, so maybe we need option 1, but it seems to have some issues

the .wait() already calls .get_future(); we want to make sure that when users of this API call .get_future(), we've already set up the stream dependencies for the work and the tensor division

@tushar00jain (Contributor, Author):

get_future now calls block_current_stream

@tushar00jain force-pushed the pr248 branch 8 times, most recently from 14510dd to 134e01f on July 29, 2025 04:18
@tushar00jain force-pushed the pr248 branch 5 times, most recently from 83957be to d607c85 on July 29, 2025 06:13
@tushar00jain changed the title from "option 2 - call work.wait inside wrapped work" to "setup stream dependencies inside work wrapper" on Jul 29, 2025
@tushar00jain force-pushed the pr248 branch 2 times, most recently from b16e0cb to 91194a1 on July 29, 2025 22:32

        self._is_set_future_callback_called = False

    def _set_future_callback(
Member:

Why does this need to live in the Work object? Can't we pass the stream + the future to the _WorkWrapper and have it manage things correctly?

@tushar00jain (Contributor, Author):

that would be ideal, but it doesn't work in all cases i think

  • for nccl, we need to call work.wait() before doing everything in _set_future_callback(), otherwise the stream dependency is not hooked up in the right order i think, i.e. we could end up calling future.wait before work.wait
  • for cpu, we can't call work.wait() because that'll block

these 2 conflict with each other, so this is what i came up with (see the sketch below)
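
The conflict described above can be pictured as a device-dependent branch. A hypothetical sketch only; the helper name and structure are assumptions, not the code this PR landed:

```python
import torch
import torch.distributed as dist


def _setup_stream_dependency(work: dist.Work, tensor: torch.Tensor) -> None:
    """Hypothetical illustration of the two conflicting cases above."""
    if tensor.is_cuda:
        # NCCL (and Gloo with CUDA tensors): hook the collective into the
        # caller's current stream *before* any future callbacks are chained,
        # so future.wait() can never overtake the collective itself.
        work.block_current_stream()
    else:
        # Gloo on CPU: calling work.wait() here would block the host, so do
        # nothing; future callbacks only run after the work completes anyway.
        pass
```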

Member:

@tushar00jain you can use work.synchronize() to set up the dependency in a guaranteed non-blocking way
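
A minimal sketch of that suggestion, assuming an NCCL-backed Work (per the follow-up below, Gloo with CUDA still needs block_current_stream instead); the helper name is made up:

```python
import torch.distributed as dist


def hook_up_collective(work: dist.Work) -> None:
    # For NCCL, work.synchronize() records the dependency between the
    # collective's internal stream and the caller's current stream without
    # blocking the host, unlike work.wait() on a CPU/Gloo work, which blocks.
    work.synchronize()
```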

@tushar00jain (Contributor, Author) commented Jul 31, 2025:

@d4l3k that's for nccl, right? for nccl you mentioned block_current_stream also just calls synchronize, but it also works in a non-blocking way for gloo. we needed block_current_stream for that because i'm guessing synchronize doesn't do that for gloo.

also, based on our discussion offline, the current APIs work for all cases and have the same semantics as the underlying process group work (the main usage patterns are sketched below)

  • in torchft, we only ever use work.wait(), and we call it only when we need to synchronize
    • for nccl, and gloo with cuda, this sets up stream deps properly with a custom stream that we synchronize on to wait for the allreduce to finish, along with the future associated with that work
    • for gloo with cpu, it just blocks until the work is done. the future callbacks run after the work is done
  • that was a lie, we also call work.get_future() in bucketized allreduce
    • in this case we call block_current_stream first to set up the stream dep for nccl (just a proxy to work.synchronize), and for gloo with cuda. we also add a callback to the future chain, but carefully set up the stream dep after all the other stream deps have been set up. that's why we call block_current_stream in get_future anyway
    • for gloo with cpu, it doesn't call anything on work because futures run after the work is done anyway
  • we will call work.block_current_stream for hsdp in torchtitan -- this is pretty much the same as the case above for bucketized allreduce
  • for ddp, we call get_future but don't expect users to do anything besides calling .wait on that future

In the future,

  • we can consider creating our own future, instead of using torch.futures.Future, that sets up stream deps the way we want it to
  • consider simplifying the implementation of _ManagedWork (the above will also help us do that)
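
As referenced above, here is a minimal sketch of the three usage patterns; the helper names and the `managed_work` parameter are illustrative assumptions, not torchft's public API:

```python
import torch
import torch.distributed as dist


def diloco_sync_step(managed_work: dist.Work) -> None:
    # Case 1: block until the averaged result is ready. wait() sets up the
    # stream deps for nccl / gloo-with-cuda, or simply blocks for gloo on cpu.
    managed_work.wait()


def bucketized_or_hsdp(managed_work: dist.Work) -> torch.futures.Future:
    # Case 2: make the current stream depend on the collective without
    # blocking the host, then keep chaining callbacks on the future.
    managed_work.block_current_stream()
    return managed_work.get_future()


def ddp_comm_hook_style(managed_work: dist.Work) -> None:
    # Case 3: hand out the future; callers only ever call .wait() on it.
    managed_work.get_future().wait()
```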

@tushar00jain force-pushed the pr248 branch 17 times, most recently from 217f0d0 to 220ddf5 on July 31, 2025 02:47

        self._is_set_future_callback_called = True

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
Member:

.wait() should set a dependency between the work and the current stream -- it looks like we're running all operations on self._stream?
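
For context, a minimal sketch of one way to add such a dependency, assuming CUDA is available and that `helper_stream` stands in for something like the wrapper's self._stream:

```python
import torch


def join_current_stream(helper_stream: torch.cuda.Stream) -> None:
    # Anything queued on the caller's current stream after this call will
    # wait for all work already queued on helper_stream to finish.
    torch.cuda.current_stream().wait_stream(helper_stream)
```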

@d4l3k (Member) left a comment:

accepting to unblock -- this seems like it will work for our current use cases


        return True

    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
Member:

we probably shouldn't rely on this until we've thought this through more / tested

@tushar00jain (Contributor, Author):

yeah, we can test it more before we change the hsdp implementation. i think we can also do some other alternative for bucketized allreduce and ddp without having to use block_current_stream

@tushar00jain force-pushed the pr248 branch 2 times, most recently from fc31fd4 to 2e5743d on August 1, 2025 06:31
Summary:
- extend the work wrapper object to also do the division post allreduce
- add api to block_current_stream on work wrapper so it can be used for HSDP
@tushar00jain merged commit d358fb4 into pytorch:main on Aug 1, 2025
14 checks passed
@tushar00jain deleted the pr248 branch on August 1, 2025 19:17
Labels: CLA Signed
3 participants