diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index e51547dbf..c84130537 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -55,6 +55,7 @@ class Perf:
     bwd_compute: float
     bwd_comms: float
     prefetch_compute: float = 0.0
+    input_dist: float = 0.0
 
     @property
     def total(self) -> float:
@@ -73,12 +74,17 @@ def total(self) -> float:
         # benefit that 1) it enables the ScaleupProposer to explore the trade off
         # between increasing cache sizes vs more difficult bin-packing constraints. 2)
         # it helps balance the prefetch compute across the ranks.
+
+        # Similarly, input_dist is often overlapped with compute kernels, but we
+        # conservatively add it as part of the total cost which models it as blocking.
+
         return (
             self.fwd_compute
             + self.bwd_compute
             + self.fwd_comms
             + self.bwd_comms
             + self.prefetch_compute
+            + self.input_dist
         )
 
     def __add__(self, other: "Perf") -> "Perf":
@@ -88,6 +94,7 @@ def __add__(self, other: "Perf") -> "Perf":
             bwd_compute=self.bwd_compute + other.bwd_compute,
             bwd_comms=self.bwd_comms + other.bwd_comms,
             prefetch_compute=self.prefetch_compute + other.prefetch_compute,
+            input_dist=self.input_dist + other.input_dist,
         )
 
     def __hash__(self) -> int:
@@ -98,6 +105,7 @@ def __hash__(self) -> int:
                 self.bwd_compute,
                 self.bwd_comms,
                 self.prefetch_compute,
+                self.input_dist,
             )
         )
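
For reference, a minimal usage sketch of the new field (the per-shard numbers are made up for illustration, and it assumes the patch above is applied so that Perf exposes input_dist): input_dist now contributes to Perf.total, defaults to 0.0, and is accumulated by Perf.__add__.

    # Illustrative values only; assumes the patch above is applied.
    from torchrec.distributed.planner.types import Perf

    shard_a = Perf(
        fwd_compute=1.0,
        fwd_comms=0.5,
        bwd_compute=2.0,
        bwd_comms=0.5,
        prefetch_compute=0.25,
        input_dist=0.1,
    )
    shard_b = Perf(
        fwd_compute=1.5,
        fwd_comms=0.5,
        bwd_compute=3.0,
        bwd_comms=0.5,
    )  # prefetch_compute and input_dist default to 0.0

    combined = shard_a + shard_b  # input_dist is accumulated: 0.1 + 0.0
    # total = fwd_compute + bwd_compute + fwd_comms + bwd_comms
    #         + prefetch_compute + input_dist
    #       = 2.5 + 5.0 + 1.0 + 1.0 + 0.25 + 0.1
    assert abs(combined.total - 9.85) < 1e-6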