Skip to content

Commit 92ad240

Browse files
committed
option 2 - call work.wait inside future callback
1 parent 405dc6e commit 92ad240

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchft/manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,13 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
397397
def callback(
398398
fut: torch.futures.Future[List[torch.Tensor]],
399399
) -> torch.Tensor:
400-
nonlocal tensor, stream, num_participants
400+
nonlocal tensor, stream, num_participants, work
401401

402402
# change the stream to avoid making the callback stream
403403
# dependent on process group stream running the allreduce
404404
with torch.cuda.stream(stream) if stream is not None else nullcontext():
405405
# Setup stream dependency
406+
work.wait()
406407
fut.wait()
407408
fut.value()
408409
tensor /= num_participants

0 commit comments

Comments
 (0)