Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dlrover/python/common/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ class DefaultValues(object):
MAX_CKPT_THRESHOLD = 900 # seconds
MAX_AVG_STEPS = 50
FIRST_GROUP_IDX = 1000 # group idx initial value for group relaunch
MAX_RELAUNCH_COUNT = 3
MAX_RELAUNCH_COUNT = 3 # maximum node relaunch count
MAX_GROUP_RELAUNCH_COUNT = 3 # maximum node group relaunch count


class Context(Singleton):
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(self):
# pre-check args
self.pre_check_operators = DefaultValues.PRE_CHECK_OPS
self.max_relaunch_count = DefaultValues.MAX_RELAUNCH_COUNT
self.max_group_relaunch_count = DefaultValues.MAX_GROUP_RELAUNCH_COUNT

def set_params_from_brain(self):
self.train_speed_record_num = self.get_param_value_from_brain(
Expand Down
10 changes: 10 additions & 0 deletions dlrover/python/master/node/dist_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def __init__(
self._scaler: Scaler = job_scaler
self._init_training_node_manager()
self._relaunched_groups: List[int] = []
self._group_relaunch_count = 0
self._max_group_relaunch_count = _dlrover_context.max_relaunch_count

def start(self):
self._scaler.start()
Expand Down Expand Up @@ -917,6 +919,13 @@ def _should_relaunch_node_group(self, node_group: int) -> bool:
f"{self._enable_relaunch_node}, {node_check}, {job_ctx.get_job_stage()}"
)

if self._group_relaunch_count > self._max_group_relaunch_count:
logger.info(
f"Node group {node_group} has exceeded max relaunch count: "
f"{self._group_relaunch_count}/{self._max_group_relaunch_count}"
)
return False

return should_relaunch

def _should_relaunch(
Expand Down Expand Up @@ -1089,6 +1098,7 @@ def _relaunch_node_group(self, node_group: int):

self._relaunched_groups.append(node_group)
self._scaler.scale(plan)
self._group_relaunch_count += 1
return plan

def clear_exited_nodes(self):
Expand Down
15 changes: 15 additions & 0 deletions dlrover/python/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,21 @@ def test_relaunch_node_group(self):
manager._init_nodes()
manager._scaler.scale = mock.MagicMock(return_value=None)

manager._max_group_relaunch_count = -1
self.job_context.clear_job_node_groups()
node = Node(
NodeType.WORKER,
0,
rank_index=0,
status=NodeStatus.PENDING,
node_group=0,
node_group_size=1,
relaunchable=True,
)
self.job_context.update_job_node_by_group(node)
self.assertFalse(manager._should_relaunch_node_group(0))
manager._max_group_relaunch_count = 3

self.job_context.clear_job_node_groups()
node = Node(
NodeType.WORKER,
Expand Down