setup stream dependencies inside work wrapper #248
Conversation
torchft/manager.py (Outdated)

```python
        return True

    def get_future(self) -> torch.futures.Future[torch.Tensor]:
        self.wait()
```
.wait() should be a blocking call, we probably want to invert this logic and make .wait() call get_future() instead
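A minimal sketch of the suggested inversion, assuming a wrapper around a process-group work object (the class and field names here are illustrative, not torchft's actual code):

```python
import torch
import torch.distributed as dist


class _WorkWrapperSketch:
    """Illustrative only: .wait() becomes the blocking entry point,
    built on top of a non-blocking .get_future()."""

    def __init__(self, work: dist.Work) -> None:
        self._work = work

    def get_future(self) -> torch.futures.Future:
        # Non-blocking: hand back the future for callback chaining.
        return self._work.get_future()

    def wait(self) -> bool:
        # Blocking: wait on the future, rather than having
        # get_future() call wait().
        self.get_future().wait()
        return True
```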
Yep, `.wait()` and `.get_future()` should only be called when we want to block. For DiLoCo, it's called at the sync step. IIUC the manager's allreduce method doesn't get called from HSDP, so it's unrelated there. This also made me realize we call `.get_future()` for bucketized allreduce, though, where we don't want to block. Thinking we can pass the callback to the manager's allreduce for that. I guess it's not nice API-wise to block in this method, so maybe we need option 1, but it seems to have some issues.
`.wait()` already calls `.get_future()`; we want to make sure that when users of this API call `.get_future()`, we've already set up the stream dependencies for the work and the tensor division.
`get_future` now calls `block_current_stream`.
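A hedged sketch of what that change looks like (`self._fut` is an assumed field holding the work's future, not necessarily the actual attribute name):

```python
def get_future(self) -> torch.futures.Future[torch.Tensor]:
    # Set up stream dependencies (and the post-allreduce division)
    # before handing the future to callers, so their chained
    # callbacks are ordered after the collective.
    self.block_current_stream()
    return self._fut  # assumed field holding the work's future
```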
```python
        self._is_set_future_callback_called = False

    def _set_future_callback(
```
Why does this need to live in the Work object? Can't we pass the stream + the future to the _WorkWrapper and have it manage things correctly?
That would be ideal but doesn't work in all cases, I think:
- For NCCL, we need to call `work.wait()` before doing everything in `_set_future_callback()`, otherwise the stream dependency isn't hooked up in the right order, i.e. we could end up calling `future.wait` before `work.wait`.
- For CPU, we can't call `work.wait()` because that'll block.

These two conflict with each other, so this is what I came up with (see the sketch below).
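A rough illustration of the conflict, assuming a CUDA-vs-CPU split on a tensor field (`self._tensor` is a placeholder; this is not the actual torchft implementation):

```python
def _set_future_callback(self) -> None:
    # Illustrative sketch only.
    if self._is_set_future_callback_called:
        return
    if self._tensor.is_cuda:
        # NCCL path: establish the stream ordering first so future
        # callbacks run after the collective. On CUDA this does not
        # block the host thread.
        self._work.wait()
    # CPU (Gloo) path: skip work.wait(), which would block the host;
    # future callbacks already run only after the work completes.
    self._fut = self._work.get_future()
    self._is_set_future_callback_called = True
```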
@tushar00jain you can use `work.synchronize()` to set up the dependency in a guaranteed non-blocking way.
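For reference, a minimal usage sketch (`pg` and `tensor` are placeholders):

```python
# work.synchronize() enqueues the dependency on the current CUDA
# stream without blocking the host thread, unlike work.wait() on a
# CPU backend.
work = pg.allreduce([tensor])
work.synchronize()       # non-blocking stream-dependency setup
fut = work.get_future()  # callbacks now ordered after the collective
```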
@d4l3k that's for NCCL, right? For NCCL, you mentioned `block_current_stream` also just calls `synchronize`, but it also works in a non-blocking way for Gloo. We needed `block_current_stream` for that because I'm guessing `synchronize` doesn't do that for Gloo.
Also, based on our discussion offline, the current APIs work for all cases and have the same semantics as the underlying process group work (a consolidated sketch follows this list):
- In torchft, we only ever use `work.wait()`, and we call it only when we need to synchronize.
  - For NCCL, and Gloo with CUDA, this sets up stream dependencies properly with a custom stream that we synchronize on to wait for the allreduce to finish, along with the future associated with that work.
  - For Gloo with CPU, it just blocks until the work is done. The future callbacks run after the work is done.
- That was a lie: we also call `work.get_future()` in bucketized allreduce.
  - In this case we call `block_current_stream` first to set up the stream dependency for NCCL (just a proxy to `work.synchronize`) and for Gloo with CUDA. We also add a callback to the future chain, but carefully set up its stream dependency after all the other stream dependencies have been set up; that's why we call `block_current_stream` in `get_future` anyway.
  - For Gloo with CPU, it doesn't call anything on the work, because futures run after the work is done anyway.
- We will call `work.block_current_stream` for HSDP in torchtitan; this is pretty much the same as the case above for bucketized allreduce.
- For DDP, we call `get_future` but don't expect users to do anything besides calling `.wait` on that future.

In the future:
- We can consider creating our own future, instead of using `torch.futures.Future`, that sets up stream dependencies the way we want.
- Consider simplifying the implementation of `_ManagedWork` (the above will also help us do that).
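A consolidated sketch of those semantics, assuming the collective ran on a side stream when on CUDA (`_ManagedWorkSketch` is illustrative, not torchft's actual `_ManagedWork`):

```python
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist


class _ManagedWorkSketch:
    """Illustrative only; mirrors the semantics described above."""

    def __init__(self, work: dist.Work, stream: Optional[torch.cuda.Stream]) -> None:
        self._work = work
        self._stream = stream  # side stream the collective ran on, if CUDA

    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
        # NCCL / Gloo-with-CUDA: order the caller's current stream after
        # the collective without blocking the host thread. No-op on CPU.
        if self._stream is not None:
            self._work.synchronize()

    def get_future(self) -> torch.futures.Future:
        # Set up stream dependencies before exposing the future so that
        # callbacks chained by callers run after the collective.
        self.block_current_stream()
        return self._work.get_future()

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        # Gloo on CPU: blocks the host until the work is done.
        # CUDA backends: blocks the stream(s) on the collective.
        self._work.wait()
        return True
```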
```python
        self._is_set_future_callback_called = True

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
```
.wait() should set a dependency between the work and the current stream -- it looks like we're running all operations on self._stream?
accepting to unblock -- this seems like it will work for our current use cases
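For context, a hedged sketch of the dependency the comment above asks for, assuming the collective ran on a side stream `self._stream` (a method fragment, not the actual implementation):

```python
def wait(self, timeout: Optional[timedelta] = None) -> bool:
    # Instead of running everything on self._stream, make the
    # caller's *current* stream wait on it, so subsequent ops the
    # caller enqueues are ordered after the collective.
    if self._stream is not None:
        torch.cuda.current_stream().wait_stream(self._stream)
    else:
        self._work.wait()  # CPU backend: block the host
    return True
```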
```python
        return True

    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
```
we probably shouldn't rely on this until we've thought this through more / tested
Yeah, we can test it more before we change the HSDP implementation. I think we can also do some other alternative for bucketized allreduce and DDP without having to use `block_current_stream`.
Summary:
- extend the work wrapper object to also do the division post-allreduce
- add an API to `block_current_stream` on the work wrapper so it can be used for HSDP
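A hedged sketch of folding the post-allreduce division into the wrapper via a future callback (`num_participants` and `work` are placeholders for however the manager tracks the live replica count and the underlying work):

```python
# Illustrative only: average instead of sum by dividing after the
# allreduce completes, chained on the work's future.
def _divide(fut: torch.futures.Future) -> torch.Tensor:
    tensor = fut.value()[0]     # allreduce futures carry a list of tensors
    tensor /= num_participants  # placeholder for the live replica count
    return tensor

fut = work.get_future().then(_divide)
```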
Stack created with Sapling. Best reviewed with ReviewStack.