Skip to content

Commit 647fcbd

Browse files
authored
Fix: deepspeed_utils.py
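The function that handles data imbalance across ranks has been renamed from "slam_join" to "deepspeed_join", and a bug has been fixed where this function was never called: the training loop now calls it at the start of each step when the batching strategy is "dynamic" and stops iterating if it returns true.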
1 parent 43d9e54 commit 647fcbd

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/slam_llm/utils/deepspeed_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any:
107107

108108
return main_decorator
109109

110-
def slam_join(group_join):
110+
def deepspeed_join(group_join):
111111
"""
112112
Copy from wenet:https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/executor.py#L64
113113
"""
@@ -204,6 +204,8 @@ def train(
204204
else:
205205
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", dynamic_ncols=True)
206206
for step, batch in enumerate(train_dataloader):
207+
if train_config.batching_strategy == "dynamic" and deepspeed_join(group_join):
208+
break
207209
for key in batch.keys():
208210
batch[key] = (
209211
batch[key].to(local_rank).half()

0 commit comments

Comments
 (0)