1414
1515import logging
1616import time
17- from typing import Any , Optional , Tuple
17+ from typing import Any , List , Optional , Tuple
1818
1919from nvflare .fuel .utils .pipe .pipe import Message , Pipe
2020from nvflare .fuel .utils .pipe .pipe_handler import PipeHandler , Topic
@@ -40,31 +40,33 @@ class ExchangePeerGoneException(DataExchangeException):
4040 pass
4141
4242
43- class ModelExchanger :
43+ class DataExchanger :
4444 def __init__ (
4545 self ,
46+ supported_topics : List [str ],
4647 pipe : Pipe ,
4748 pipe_name : str = "pipe" ,
48- topic : str = "data" ,
4949 get_poll_interval : float = 0.5 ,
5050 read_interval : float = 0.1 ,
5151 heartbeat_interval : float = 5.0 ,
5252 heartbeat_timeout : float = 30.0 ,
5353 ):
54- """Initializes the ModelExchanger .
54+ """Initializes the DataExchanger .
5555
5656 Args:
57+ supported_topics (list[str]): Supported topics for data exchange. This allows the sender and receiver to identify
58+ the purpose or content of the data being exchanged.
5759 pipe (Pipe): The pipe used for data exchange.
5860 pipe_name (str): Name of the pipe. Defaults to "pipe".
59- topic (str): Topic for data exchange. Defaults to "data".
6061 get_poll_interval (float): Interval for checking if the other side has sent data. Defaults to 0.5.
6162 read_interval (float): Interval for reading from the pipe. Defaults to 0.1.
6263 heartbeat_interval (float): Interval for sending heartbeat to the peer. Defaults to 5.0.
6364 heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer. Defaults to 30.0.
6465 """
6566 self .logger = logging .getLogger (self .__class__ .__name__ )
6667 self ._req_id : Optional [str ] = None
67- self ._topic = topic
68+ self .current_topic : Optional [str ] = None
69+ self ._supported_topics = supported_topics
6870
6971 pipe .open (pipe_name )
7072 self .pipe_handler = PipeHandler (
@@ -76,51 +78,56 @@ def __init__(
7678 self .pipe_handler .start ()
7779 self ._get_poll_interval = get_poll_interval
7880
79- def submit_model (self , model : Any ) -> None :
80- """Submits a model for exchange.
81+ def submit_data (self , data : Any ) -> None :
82+ """Submits a data for exchange.
8183
8284 Args:
83- model (Any): The model to be submitted.
85+ data (Any): The data to be submitted.
8486
8587 Raises:
86- DataExchangeException: If there is no request ID available (needs to pull model from server first).
88+ DataExchangeException: If there is no request ID available (needs to pull data from server first).
8789 """
8890 if self ._req_id is None :
89- raise DataExchangeException ("need to pull a model first." )
90- self ._send_reply (data = model , req_id = self ._req_id )
91+ raise DataExchangeException ("Missing req_id, need to pull a data first." )
9192
92- def receive_model (self , timeout : Optional [float ] = None ) -> Any :
93- """Receives a model.
93+ if self .current_topic is None :
94+ raise DataExchangeException ("Missing current_topic, need to pull a data first." )
95+
96+ self ._send_reply (data = data , topic = self .current_topic , req_id = self ._req_id )
97+
98+ def receive_data (self , timeout : Optional [float ] = None ) -> Tuple [str , Any ]:
99+ """Receives a data.
94100
95101 Args:
96- timeout (Optional[float]): Timeout for waiting to receive a model . Defaults to None.
102+ timeout (Optional[float]): Timeout for waiting to receive a data . Defaults to None.
97103
98104 Returns:
99- Any : The received model .
105+ A tuple of (topic, data) : The received data .
100106
101107 Raises:
102108 ExchangeTimeoutException: If the data cannot be received within the specified timeout.
103109 ExchangeAbortException: If the other endpoint of the pipe requests to abort.
104110 ExchangeEndException: If the other endpoint has ended.
105111 ExchangePeerGoneException: If the other endpoint is gone.
106112 """
107- model , req_id = self ._receive_request (timeout )
108- self ._req_id = req_id
109- return model
113+ msg = self ._receive_request (timeout )
114+ self ._req_id = msg .msg_id
115+ self .current_topic = msg .topic
116+ return msg .topic , msg .data
110117
111118 def finalize (self , close_pipe : bool = True ) -> None :
112119 if self .pipe_handler is None :
113120 raise RuntimeError ("PipeMonitor is not initialized." )
114121 self .pipe_handler .stop (close_pipe = close_pipe )
115122
116- def _receive_request (self , timeout : Optional [float ] = None ) -> Tuple [ Any , str ] :
123+ def _receive_request (self , timeout : Optional [float ] = None ) -> Message :
117124 """Receives a request.
118125
119126 Args:
120127 timeout: how long to wait for the request to come.
121128
122129 Returns:
123- A tuple of (data, request id) .
130+ A Message .
124131
125132 Raises:
126133 ExchangeTimeoutException: If can't receive data within timeout seconds.
@@ -138,24 +145,21 @@ def _receive_request(self, timeout: Optional[float] = None) -> Tuple[Any, str]:
138145 self .pipe_handler .notify_abort (msg )
139146 raise ExchangeTimeoutException (f"get data timeout after { timeout } secs" )
140147 elif msg .topic == Topic .ABORT :
141- raise ExchangeAbortException ("the other end is aborted " )
148+ raise ExchangeAbortException ("the other end ask to abort " )
142149 elif msg .topic == Topic .END :
143- raise ExchangeEndException (
144- f"received { msg .topic } : { msg .data } while waiting for result for { self ._topic } "
145- )
150+ raise ExchangeEndException (f"received msg: '{ msg } ' while waiting for requests" )
146151 elif msg .topic == Topic .PEER_GONE :
147- raise ExchangePeerGoneException (
148- f"received { msg .topic } : { msg .data } while waiting for result for { self ._topic } "
149- )
150- elif msg .topic == self ._topic :
151- return msg .data , msg .msg_id
152+ raise ExchangePeerGoneException (f"received msg: '{ msg } ' while waiting for requests" )
153+ elif msg .topic in self ._supported_topics :
154+ return msg
152155 time .sleep (self ._get_poll_interval )
153156
154- def _send_reply (self , data : Any , req_id : str , timeout : Optional [float ] = None ) -> bool :
157+ def _send_reply (self , data : Any , topic : str , req_id : str , timeout : Optional [float ] = None ) -> bool :
155158 """Sends a reply.
156159
157160 Args:
158161 data: The data exchange object to be sent.
162+ topic: message topic.
159163 req_id: request ID.
160164 timeout: how long to wait for the peer to read the data.
161165 If not specified, return False immediately.
@@ -165,6 +169,6 @@ def _send_reply(self, data: Any, req_id: str, timeout: Optional[float] = None) -
165169 """
166170 if self .pipe_handler is None :
167171 raise RuntimeError ("PipeMonitor is not initialized." )
168- msg = Message .new_reply (topic = self . _topic , data = data , req_msg_id = req_id )
172+ msg = Message .new_reply (topic = topic , data = data , req_msg_id = req_id )
169173 has_been_read = self .pipe_handler .send_to_peer (msg , timeout )
170174 return has_been_read
0 commit comments