Skip to content

Commit 7b01b0f

Browse files
authored
Added check for duplicate RM request (#2858)
* Added check for duplicate RM request
* Addressed PR comment
1 parent dd90ddb commit 7b01b0f

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

nvflare/apis/utils/reliable_message.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
STATUS_NOT_RECEIVED = "not_received"
4646
STATUS_REPLIED = "replied"
4747
STATUS_ABORTED = "aborted"
48+
STATUS_DUP_REQUEST = "dup_request"
4849

4950
# Topics for Reliable Message
5051
TOPIC_RELIABLE_REQUEST = "RM.RELIABLE_REQUEST"
@@ -227,6 +228,7 @@ class ReliableMessage:
227228

228229
_topic_to_handle = {}
229230
_req_receivers = {} # tx id => receiver
231+
_req_completed = {} # tx id => expiration
230232
_enabled = False
231233
_executor = None
232234
_query_interval = 1.0
@@ -293,6 +295,9 @@ def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext):
293295
# no handler registered for this topic!
294296
cls.error(fl_ctx, f"no handler registered for request {rm_topic=}")
295297
return make_reply(ReturnCode.TOPIC_UNKNOWN)
298+
if cls._req_completed.get(tx_id):
299+
cls.debug(fl_ctx, "Completed tx_id received")
300+
return _status_reply(STATUS_DUP_REQUEST)
296301
receiver = cls._get_or_create_receiver(rm_topic, request, handler_f)
297302
cls.debug(fl_ctx, f"received request {rm_topic=}")
298303
return receiver.process(request, fl_ctx)
@@ -336,6 +341,7 @@ def release_request_receiver(cls, receiver: _RequestReceiver, fl_ctx: FLContext)
336341
337342
"""
338343
with cls._tx_lock:
344+
cls._register_completed_req(receiver.tx_id, receiver.tx_timeout)
339345
cls._req_receivers.pop(receiver.tx_id, None)
340346
cls.debug(fl_ctx, f"released request receiver of TX {receiver.tx_id}")
341347

@@ -679,3 +685,15 @@ def _query_result(
679685
cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}")
680686
else:
681687
cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}")
688+
689+
@classmethod
def _register_completed_req(cls, tx_id, tx_timeout):
    """Record a transaction id as completed and drop entries that have expired.

    The entry for ``tx_id`` is stored in ``cls._req_completed`` with an
    expiration timestamp of "now + 2 * tx_timeout". Before storing it, every
    entry whose expiration timestamp is already in the past is removed.

    Args:
        tx_id: id of the transaction that has just been completed.
        tx_timeout: timeout of the transaction; presumably in seconds, since
            it is added to ``time.time()`` - confirm against the caller.
    """
    current_time = time.time()
    completed = cls._req_completed

    # Collect the stale ids from a snapshot first, so the dict is never
    # resized while it is being iterated.
    stale_ids = [
        req_id
        for req_id, expires_at in list(completed.items())
        if expires_at and expires_at < current_time
    ]
    for req_id in stale_ids:
        completed.pop(req_id, None)

    # Keep the completed id around for twice the transaction timeout.
    completed[tx_id] = current_time + 2 * tx_timeout

0 commit comments

Comments
 (0)