Skip to content

Commit 6e5af1e

Browse files
authored
Revert tracking of Work status for FlightRecorder in ProcessGroupXCCL (#2076)
The callback used to track the work status in ProcessGroupXCCL was causing an unintended memory leak by keeping the work objects alive and therefore the stashed tensors as well. For now, I'm removing the callback, and I have added a unit test to ensure this memory leak doesn't return. Fix #2084
1 parent bc52e63 commit 6e5af1e

File tree

3 files changed

+36
-17
lines changed

3 files changed

+36
-17
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -437,17 +437,6 @@ void ProcessGroupXCCL::setEnqueuedPgStatus(
437437
pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
438438
}
439439

440-
void ProcessGroupXCCL::setCompletedPgStatus(
441-
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
442-
pgStatus_->lastCompletedSeq = static_cast<int64_t>(work->getSequencenumber());
443-
pgStatus_->lastCompletedWorkName = opTypeToString(work->opType_);
444-
pgStatus_->lastCompletedNumelIn = work->numelIn_;
445-
pgStatus_->lastCompletedNumelOut = work->numelOut_;
446-
// To avoid complexity, we're not computing duration.
447-
FlightRecorderXCCL::get()->retire_id(
448-
work->trace_id_, /*compute_duration*/ false);
449-
}
450-
451440
void ProcessGroupXCCL::setSequenceNumberForGroup() {}
452441

453442
uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() {
@@ -777,8 +766,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
777766
work->future_ = c10::make_intrusive<at::ivalue::Future>(
778767
c10::ListType::create(c10::TensorType::get()), devices);
779768
work->future_->markCompleted(at::IValue(*work->outputs_));
769+
auto id = work->trace_id_;
780770
work->future_->addCallback(
781-
[this, work](at::ivalue::Future&) { this->setCompletedPgStatus(work); });
771+
[id](at::ivalue::Future&) {
772+
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
773+
},
774+
/*use_future*/ false);
782775
work->blockingWait_ = blockingWait_;
783776

784777
work->numelIn_ = 0;
@@ -889,9 +882,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
889882
work->future_ = c10::make_intrusive<at::ivalue::Future>(
890883
c10::ListType::create(c10::TensorType::get()), devices);
891884
work->future_->markCompleted(at::IValue(*work->outputs_));
892-
work->future_->addCallback([this, work](at::ivalue::Future&) {
893-
this->setCompletedPgStatus(work);
894-
});
885+
auto id = work->trace_id_;
886+
work->future_->addCallback(
887+
[id](at::ivalue::Future&) {
888+
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
889+
},
890+
/*use_future*/ false);
895891

896892
work->numelIn_ = work->numelOut_ = tensor.numel();
897893
setEnqueuedPgStatus(work);

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
424424

425425
const std::vector<uint64_t>& groupRanks() const;
426426
void setEnqueuedPgStatus(c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
427-
void setCompletedPgStatus(
428-
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
429427
bool dumpDebuggingInfo(bool includeStackTrace = true);
430428

431429
protected:

test/xpu/distributed/test_c10d_xccl.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,31 @@ def test_nan_assert(self, type):
365365
# reset env
366366
os.environ["TORCH_XCCL_NAN_CHECK"] = "0"
367367

368+
@requires_xccl()
369+
@skip_if_lt_x_gpu(2)
370+
def test_oom(self):
371+
pg = self._create_process_group_xccl()
372+
dp_ranks = range(0, self.world_size)
373+
dp_group = c10d.new_group(dp_ranks)
374+
device = torch.device(f"xpu:{self.rank}")
375+
torch.xpu.set_device(device)
376+
377+
shape = (16384 * 2, 16384 * 2)
378+
weight = torch.ones(shape, device=device).half()
379+
gradient = torch.zeros(shape, device=device).half()
380+
ret = torch.randn(shape, device=device).half()
381+
382+
for iter in range(50):
383+
output = torch.empty_like(ret)
384+
output = ret + weight + gradient
385+
ret = torch.nn.functional.linear(output, weight=ret)
386+
dist.all_reduce(ret, op=dist.ReduceOp.SUM)
387+
torch.xpu.synchronize()
388+
self.assertLess(
389+
torch.xpu.max_memory_allocated(),
390+
torch.xpu.max_memory_reserved() * 2,
391+
)
392+
368393

369394
class CommTest(MultiProcessTestCase):
370395
@property

0 commit comments

Comments
 (0)