Skip to content

Commit 0c0b722

Browse files
authored
Use depends in pipeline parallel (#483)
1 parent dcb4b9b commit 0c0b722

File tree

3 files changed, with 5 additions and 1 deletion.

mlx_lm/models/deepseek_v2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,8 @@ def __call__(
         # Send to the next process in the pipeline
         if pipeline_rank != 0:
             h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
+            if cache[-1] is not None:
+                cache[-1].keys = mx.depends(cache[-1].keys, h)

         # Broadcast h while keeping it in the graph
         h = mx.distributed.all_gather(h)[: h.shape[0]]

mlx_lm/models/deepseek_v3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,8 @@ def __call__(
         # Send to the next process in the pipeline
         if pipeline_rank != 0:
             h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
+            if cache[-1] is not None:
+                cache[-1].keys = mx.depends(cache[-1].keys, h)

         # Broadcast h while keeping it in the graph
         h = mx.distributed.all_gather(h)[: h.shape[0]]

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
-mlx>=0.29.1
+mlx>=0.29.2
 numpy
 transformers>=4.39.3
 protobuf

0 commit comments

Comments
 (0)