Skip to content

Commit b7aa7cf

Browse files
committed
refactor: recovery after agent kill and rejoin
When an agent that had deployed workers was destroyed, the previous implementation skipped that agent (the one that had deployed the failed workers) if it was not found during the first iteration of recovery, and continued with the remaining agents only. This commit refactors the code to remove that limitation, so that an agent which fails and is restarted is included in the recovery process. It also adds handling for when an agent joins the controller: all existing job contexts are updated with the details of the new agent, so whenever an agent joins, every job context is able to use its resources.
1 parent 2509d39 commit b7aa7cf

File tree

2 files changed

+18
-28
lines changed

2 files changed

+18
-28
lines changed

infscale/controller/controller.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ async def run(self):
9999

100100
await self.apiserver.run()
101101

102+
def _manage_agent_in_job_ctx(self) -> None:
103+
"""Manage agent data in job contexts when a new agent registers."""
104+
for ctx in self.job_contexts.values():
105+
ctx.manage_agent_metadata()
106+
102107
async def handle_register(self, req: pb2.RegReq) -> tuple[bool, str]:
103108
"""Handle registration message."""
104109
if req.id in self.agent_contexts:
@@ -107,6 +112,7 @@ async def handle_register(self, req: pb2.RegReq) -> tuple[bool, str]:
107112
self.agent_contexts[req.id] = AgentContext(self, req.id, req.ip)
108113
# since registration is done, let's keep agent context alive
109114
self.agent_contexts[req.id].keep_alive()
115+
self._manage_agent_in_job_ctx()
110116

111117
return True, ""
112118

infscale/controller/job_context.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -518,25 +518,9 @@ def _get_wrk_resources_map(self, wrk_ids: set[str]) -> dict[str, str]:
518518
agent_gpu_map: dict[str, set[int]] = {}
519519

520520
for wrk_id in wrk_ids:
521-
curr_agent = self._get_curr_agent_data(wrk_id)
522-
assign_success = False
523-
# current agent id might not be available in the case of
524-
# recover due to agent failure
525-
curr_agent_id = curr_agent.id if curr_agent else ""
526-
527-
if curr_agent:
528-
assign_success = self._assign_available_gpu_to_worker(
529-
curr_agent_id,
530-
agent_resources[curr_agent_id],
531-
wrk_id,
532-
wrk_agent_map,
533-
agent_gpu_map,
534-
)
535-
536-
if not assign_success:
537-
assign_success = self._search_gpu_on_all_agents(
538-
agent_resources, curr_agent_id, wrk_id, wrk_agent_map, agent_gpu_map
539-
)
521+
assign_success = self._search_gpu_on_all_agents(
522+
agent_resources, wrk_id, wrk_agent_map, agent_gpu_map
523+
)
540524

541525
if not assign_success:
542526
# if no resources, return and let while loop continue
@@ -594,7 +578,6 @@ def _assign_available_gpu_to_worker(
594578
def _search_gpu_on_all_agents(
595579
self,
596580
agent_resources: dict[str, AgentResources],
597-
curr_agent_id: str,
598581
wrk_id: str,
599582
wrk_agent_map: dict[str, tuple[str, int]],
600583
agent_gpu_map: dict[str, set[int]],
@@ -604,15 +587,16 @@ def _search_gpu_on_all_agents(
604587
Returns:
605588
bool: True if a GPU was successfully assigned, False otherwise.
606589
"""
590+
assign_success = False
607591
for agent_id, resources in agent_resources.items():
608-
if agent_id == curr_agent_id:
609-
continue
610-
611-
return self._assign_available_gpu_to_worker(
592+
assign_success = self._assign_available_gpu_to_worker(
612593
agent_id, resources, wrk_id, wrk_agent_map, agent_gpu_map
613594
)
595+
596+
if assign_success:
597+
break
614598

615-
return False
599+
return assign_success
616600

617601
def enum_(self) -> JobStateEnum:
618602
"""Return recovery state enum."""
@@ -1207,7 +1191,7 @@ def _get_state_class(self, state_enum: JobStateEnum):
12071191
}
12081192
return state_mapping[state_enum]
12091193

1210-
def _manage_agent_metadata(self) -> None:
1194+
def manage_agent_metadata(self) -> None:
12111195
"""Manage agent metadata by create/update/delete."""
12121196
agent_contexts = self.ctrl.agent_contexts
12131197

@@ -1380,7 +1364,7 @@ async def __update(self):
13801364
# DO NOT call this method in job_context instance or any other places.
13811365
# Call it only in methods of a state instance
13821366
# (e.g., RunningState, RecoveryState, etc).
1383-
self._manage_agent_metadata()
1367+
self.manage_agent_metadata()
13841368

13851369
try:
13861370
self.process_cfg()
@@ -1448,7 +1432,7 @@ async def __start(self):
14481432
# DO NOT call this method in job_context instance or any other places.
14491433
# Call it only in methods of a state instance
14501434
# (e.g., ReadyState, CompleteState, etc).
1451-
self._manage_agent_metadata()
1435+
self.manage_agent_metadata()
14521436

14531437
self._check_agent_info()
14541438

0 commit comments

Comments
 (0)