29 | 29 | from nvflare.security.logging import secure_format_exception |
30 | 30 | from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector |
31 | 31 |
| 32 | +_TASK_CHECK_RESULT_OK = 0 |
| 33 | +_TASK_CHECK_RESULT_TRY_AGAIN = 1 |
| 34 | +_TASK_CHECK_RESULT_TASK_GONE = 2 |
| 35 | + |
32 | 36 |
33 | 37 | class ClientRunnerConfig(object): |
34 | 38 | def __init__( |
@@ -107,7 +111,8 @@ def __init__( |
107 | 111 | self.task_lock = threading.Lock() |
108 | 112 | self.end_run_fired = False |
109 | 113 | self.end_run_lock = threading.Lock() |
110 | | - |
| 114 | + self.task_check_timeout = 5.0 |
| 115 | + self.task_check_interval = 5.0 |
111 | 116 | self._register_aux_message_handler(engine) |
112 | 117 |
113 | 118 | def _register_aux_message_handler(self, engine): |
@@ -473,19 +478,108 @@ def fetch_and_run_one_task(self, fl_ctx) -> (float, bool): |
473 | 478 | if cookie_jar: |
474 | 479 | task_reply.set_cookie_jar(cookie_jar) |
475 | 480 |
476 | | - reply_sent = self.engine.send_task_result(task_reply, fl_ctx) |
477 | | - if reply_sent: |
478 | | - self.log_info(fl_ctx, "result sent to server for task: name={}, id={}".format(task.name, task.task_id)) |
479 | | - else: |
480 | | - self.log_error( |
481 | | - fl_ctx, |
482 | | - "failed to send result to server for task: name={}, id={}".format(task.name, task.task_id), |
483 | | - ) |
| 481 | + self._send_task_result(task_reply, task.task_id, fl_ctx) |
484 | 482 | self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT") |
485 | 483 | self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx) |
486 | 484 |
487 | 485 | return task_fetch_interval, True |
488 | 486 |
| 487 | + def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext) -> bool:
| 488 | + try_count = 1 |
| 489 | + while True: |
| 490 | + self.log_info(fl_ctx, f"try #{try_count}: sending task result to server") |
| 491 | + |
| 492 | + if self.asked_to_stop: |
| 493 | + self.log_info(fl_ctx, "job aborted: stopped trying to send result") |
| 494 | + return False |
| 495 | + |
| 496 | + try_count += 1 |
| 497 | + rc = self._try_send_result_once(result, task_id, fl_ctx) |
| 498 | + |
| 499 | + if rc == _TASK_CHECK_RESULT_OK: |
| 500 | + return True |
| 501 | + elif rc == _TASK_CHECK_RESULT_TASK_GONE: |
| 502 | + return False |
| 503 | + else: |
| 504 | + # retry |
| 505 | + time.sleep(self.task_check_interval) |
| 506 | + |
| 507 | + def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext) -> int:
| 508 | + # wait until server is ready to receive |
| 509 | + while True: |
| 510 | + if self.asked_to_stop: |
| 511 | + return _TASK_CHECK_RESULT_TASK_GONE |
| 512 | + |
| 513 | + rc = self._check_task_once(task_id, fl_ctx) |
| 514 | + if rc == _TASK_CHECK_RESULT_OK: |
| 515 | + break |
| 516 | + elif rc == _TASK_CHECK_RESULT_TASK_GONE: |
| 517 | + return rc |
| 518 | + else: |
| 519 | + # try again |
| 520 | + time.sleep(self.task_check_interval) |
| 521 | + |
| 522 | + # try to send the result |
| 523 | + self.log_info(fl_ctx, "start to send task result to server") |
| 524 | + reply_sent = self.engine.send_task_result(result, fl_ctx) |
| 525 | + if reply_sent: |
| 526 | + self.log_info(fl_ctx, "task result sent to server") |
| 527 | + return _TASK_CHECK_RESULT_OK |
| 528 | + else: |
| 529 | + self.log_error(fl_ctx, "failed to send task result to server - will try again") |
| 530 | + return _TASK_CHECK_RESULT_TRY_AGAIN |
| 531 | + |
| 532 | + def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int: |
| 533 | + """This method checks whether the server is still waiting for the specified task. |
| 534 | + The real reason for this method is to fight against unstable network connections. |
| 535 | + We try to make sure that when we send task result to the server, the connection is available. |
| 536 | + If the task check succeeds, then the network connection is likely to be available. |
| 537 | + Otherwise, we keep retrying until task check succeeds or the server tells us that the task is gone (timed out). |
| 538 | +
|
| 539 | + Args: |
| 540 | + task_id: |
| 541 | + fl_ctx: |
| 542 | +
|
| 543 | + Returns: |
| 544 | +
|
| 545 | + """ |
| 546 | + self.log_info(fl_ctx, "checking task ...") |
| 547 | + task_check_req = Shareable() |
| 548 | + task_check_req.set_header(ReservedKey.TASK_ID, task_id) |
| 549 | + resp = self.engine.send_aux_request( |
| 550 | + targets=[FQCN.ROOT_SERVER], |
| 551 | + topic=ReservedTopic.TASK_CHECK, |
| 552 | + request=task_check_req, |
| 553 | + timeout=self.task_check_timeout, |
| 554 | + fl_ctx=fl_ctx, |
| 555 | + optional=True, |
| 556 | + ) |
| 557 | + if resp and isinstance(resp, dict): |
| 558 | + reply = resp.get(FQCN.ROOT_SERVER) |
| 559 | + if not isinstance(reply, Shareable): |
| 560 | + self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}") |
| 561 | + return _TASK_CHECK_RESULT_TRY_AGAIN |
| 562 | + |
| 563 | + rc = reply.get_return_code() |
| 564 | + if rc == ReturnCode.OK: |
| 565 | + return _TASK_CHECK_RESULT_OK |
| 566 | + elif rc == ReturnCode.COMMUNICATION_ERROR: |
| 567 | + self.log_error(fl_ctx, f"failed task_check: {rc}") |
| 568 | + return _TASK_CHECK_RESULT_TRY_AGAIN |
| 569 | + elif rc == ReturnCode.SERVER_NOT_READY: |
| 570 | + self.log_error(fl_ctx, f"server rejected task_check: {rc}") |
| 571 | + return _TASK_CHECK_RESULT_TRY_AGAIN |
| 572 | + elif rc == ReturnCode.TASK_UNKNOWN: |
| 573 | + self.log_error(fl_ctx, f"task no longer exists on server: {rc}") |
| 574 | + return _TASK_CHECK_RESULT_TASK_GONE |
| 575 | + else: |
| 576 | + # this should never happen |
| 577 | + self.log_error(fl_ctx, f"programming error: received {rc} from server") |
| 578 | + return _TASK_CHECK_RESULT_OK # try to push the result regardless |
| 579 | + else: |
| 580 | + self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}") |
| 581 | + return _TASK_CHECK_RESULT_TRY_AGAIN |
| 582 | + |
489 | 583 | def run(self, app_root, args): |
490 | 584 | self.init_run(app_root, args) |
491 | 585 |
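
Taken together, `_send_task_result`, `_try_send_result_once`, and `_check_task_once` form a check-then-send retry loop: a lightweight `task_check` aux request first confirms that the server still wants the result (and that the connection works), only then is the possibly large result pushed, and the whole cycle repeats until delivery succeeds, the task is gone, or the job is aborted. Below is a simplified, self-contained sketch of that control flow; `TaskCheckResult`, `check_task`, `send_result`, and `should_stop` are illustrative stand-ins, not NVFlare APIs.

```python
from enum import IntEnum
import time


class TaskCheckResult(IntEnum):
    """Illustrative mirror of the module-level _TASK_CHECK_RESULT_* constants."""
    OK = 0
    TRY_AGAIN = 1
    TASK_GONE = 2


def send_result_with_check(check_task, send_result, should_stop, check_interval=5.0):
    """Deliver a result only over a connection that a cheap check has just confirmed.

    check_task() and send_result() are hypothetical stand-ins for
    ClientRunner._check_task_once() and engine.send_task_result().
    """
    while not should_stop():
        rc = check_task()
        if rc == TaskCheckResult.TASK_GONE:
            return False                    # server no longer wants this result
        if rc == TaskCheckResult.TRY_AGAIN:
            time.sleep(check_interval)      # connection not confirmed; check again
            continue
        # rc == OK: connection looks healthy, push the (possibly large) result now
        if send_result():
            return True
        time.sleep(check_interval)          # send failed; fall back to checking
    return False                            # job aborted


if __name__ == "__main__":
    # Toy run: the first check asks for a retry, the second succeeds, then the send goes through.
    replies = iter([TaskCheckResult.TRY_AGAIN, TaskCheckResult.OK])
    delivered = send_result_with_check(
        check_task=lambda: next(replies),
        send_result=lambda: True,
        should_stop=lambda: False,
        check_interval=0.01,
    )
    print("delivered:", delivered)
```

The real implementation splits this into two nested loops, so a failed send always falls back to the connection check before the next attempt.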
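
A note on the two new retry knobs: `task_check_timeout` and `task_check_interval` are hard-coded to 5.0 seconds in `__init__`. A minimal tuning sketch, assuming a deployment already holds the framework-constructed runner (the `runner` variable below is hypothetical, not something this change exposes):

```python
# Hypothetical tuning sketch: `runner` stands in for an already-constructed ClientRunner.
# Both values are plain float attributes set in __init__, so they can be overridden directly.
runner.task_check_timeout = 10.0   # seconds to wait for each task_check aux reply
runner.task_check_interval = 2.0   # seconds to sleep between retries of the check/send loop
```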