File tree Expand file tree Collapse file tree 3 files changed +5
-1
lines changed Expand file tree Collapse file tree 3 files changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -414,6 +414,8 @@ def __call__(
414414 # Send to the next process in the pipeline
415415 if pipeline_rank != 0 :
416416 h = mx .distributed .send (h , (pipeline_rank - 1 ) % pipeline_size )
417+ if cache [- 1 ] is not None :
418+ cache [- 1 ].keys = mx .depends (cache [- 1 ].keys , h )
417419
418420 # Broadcast h while keeping it in the graph
419421 h = mx .distributed .all_gather (h )[: h .shape [0 ]]
Original file line number Diff line number Diff line change @@ -446,6 +446,8 @@ def __call__(
446446 # Send to the next process in the pipeline
447447 if pipeline_rank != 0 :
448448 h = mx .distributed .send (h , (pipeline_rank - 1 ) % pipeline_size )
449+ if cache [- 1 ] is not None :
450+ cache [- 1 ].keys = mx .depends (cache [- 1 ].keys , h )
449451
450452 # Broadcast h while keeping it in the graph
451453 h = mx .distributed .all_gather (h )[: h .shape [0 ]]
Original file line number Diff line number Diff line change 1- mlx >= 0.29.1
1+ mlx >= 0.29.2
22numpy
33transformers >= 4.39.3
44protobuf
You can’t perform that action at this time.
0 commit comments