|
45 | 45 | STATUS_NOT_RECEIVED = "not_received" |
46 | 46 | STATUS_REPLIED = "replied" |
47 | 47 | STATUS_ABORTED = "aborted" |
| 48 | +STATUS_DUP_REQUEST = "dup_request" |
48 | 49 |
|
49 | 50 | # Topics for Reliable Message |
50 | 51 | TOPIC_RELIABLE_REQUEST = "RM.RELIABLE_REQUEST" |
@@ -227,6 +228,7 @@ class ReliableMessage: |
227 | 228 |
|
228 | 229 | _topic_to_handle = {} |
229 | 230 | _req_receivers = {} # tx id => receiver |
| 231 | + _req_completed = {} # tx id => expiration |
230 | 232 | _enabled = False |
231 | 233 | _executor = None |
232 | 234 | _query_interval = 1.0 |
@@ -293,6 +295,9 @@ def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext): |
293 | 295 | # no handler registered for this topic! |
294 | 296 | cls.error(fl_ctx, f"no handler registered for request {rm_topic=}") |
295 | 297 | return make_reply(ReturnCode.TOPIC_UNKNOWN) |
| 298 | + if cls._req_completed.get(tx_id): |
| 299 | + cls.debug(fl_ctx, "Completed tx_id received") |
| 300 | + return _status_reply(STATUS_DUP_REQUEST) |
296 | 301 | receiver = cls._get_or_create_receiver(rm_topic, request, handler_f) |
297 | 302 | cls.debug(fl_ctx, f"received request {rm_topic=}") |
298 | 303 | return receiver.process(request, fl_ctx) |
@@ -336,6 +341,7 @@ def release_request_receiver(cls, receiver: _RequestReceiver, fl_ctx: FLContext) |
336 | 341 |
|
337 | 342 | """ |
338 | 343 | with cls._tx_lock: |
| 344 | + cls._register_completed_req(receiver.tx_id, receiver.tx_timeout) |
339 | 345 | cls._req_receivers.pop(receiver.tx_id, None) |
340 | 346 | cls.debug(fl_ctx, f"released request receiver of TX {receiver.tx_id}") |
341 | 347 |
|
@@ -679,3 +685,15 @@ def _query_result( |
679 | 685 | cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}") |
680 | 686 | else: |
681 | 687 | cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}") |
| 688 | + |
| 689 | + @classmethod |
| 690 | + def _register_completed_req(cls, tx_id, tx_timeout): |
| 691 | + # Remove expired entries, need to use a copy of the keys |
| 692 | + now = time.time() |
| 693 | + for key in list(cls._req_completed.keys()): |
| 694 | + expiration = cls._req_completed.get(key) |
| 695 | + if expiration and expiration < now: |
| 696 | + cls._req_completed.pop(key, None) |
| 697 | + |
| 698 | + # Expire in 2 x tx_timeout |
| 699 | + cls._req_completed[tx_id] = now + 2 * tx_timeout |
0 commit comments