diff --git a/docs/zmq.md b/docs/zmq.md index ba86911..833d682 100644 --- a/docs/zmq.md +++ b/docs/zmq.md @@ -91,5 +91,9 @@ For `predictor`, execute: ```sh predictor --zmq --zmq_source tmio -m write_async -f 100 ``` - +### Return frequency information to TMIO +To initiate `predicator` to return information to TMIO's prefetcher use +```sh +predictor --zmq --zmq_source tmio --zmq_port_reply 5556 -m read_sync +```
diff --git a/ftio/cli/predictor.py b/ftio/cli/predictor.py index 24d9722..1b34c5b 100644 --- a/ftio/cli/predictor.py +++ b/ftio/cli/predictor.py @@ -49,6 +49,7 @@ def main(args: list[str] = sys.argv) -> None: predictor_with_processes_zmq( shared_resources, args, + any("zmq_port_reply" in x for x in args), ) else: # prediction with processes and a callback mechanism diff --git a/ftio/parse/args.py b/ftio/parse/args.py index 3e5a0b0..b7a180a 100644 --- a/ftio/parse/args.py +++ b/ftio/parse/args.py @@ -327,6 +327,14 @@ def parse_args(argv: list, name="") -> argparse.Namespace: help="zmq port for communication", ) parser.set_defaults(zmq_port="5555") + parser.add_argument( + "--zmq_port_reply", + "--zmq_port_reply", + dest="zmq_port_reply", + type=str, + help="zmq port for communicating dominant frequency", + ) + parser.set_defaults(zmq_port_reply="5556") # filter arguments parser.add_argument( "--filter_type", diff --git a/ftio/prediction/processes_zmq.py b/ftio/prediction/processes_zmq.py index fd26b54..15aa1ed 100644 --- a/ftio/prediction/processes_zmq.py +++ b/ftio/prediction/processes_zmq.py @@ -13,6 +13,7 @@ from __future__ import annotations +import struct import subprocess import zmq @@ -20,7 +21,7 @@ from ftio.freq.helper import MyConsole from ftio.multiprocessing.async_process import handle_in_process, join_procs from ftio.parse.args import parse_args -from ftio.prediction.helper import export_extrap, print_data +from ftio.prediction.helper import export_extrap, get_dominant_and_conf, print_data from ftio.prediction.processes import prediction_process CONSOLE = MyConsole() @@ -28,8 +29,9 @@ def predictor_with_processes_zmq( - shard_resources, + shared_resources, args, + return_data: bool = False, ) -> None: """performs prediction in ProcessPoolExecuter. FTIO is a submitted future and probability is calculated as a callback @@ -41,13 +43,19 @@ def predictor_with_processes_zmq( # parse arguments tmp_args = parse_args(args) addr = tmp_args.zmq_address - port = tmp_args.zmq_port + port_in = tmp_args.zmq_port # bind the socket - socket = bind_socket(addr, port) + socket_in = setup_socket(addr, port_in, zmq.PULL) + socket_out = None + + if return_data: + port_out = tmp_args.zmq_port_reply + socket_out = setup_socket(addr, port_out, zmq.PUSH, False) + # can be extended to listen to multiple sockets poller = zmq.Poller() - poller.register(socket, zmq.POLLIN) + poller.register(socket_in, zmq.POLLIN) if "-zmq" not in args: args.extend(["--zmq"]) @@ -56,10 +64,20 @@ def predictor_with_processes_zmq( try: with CONSOLE.status("[green]started\n", spinner="arrow3") as status: while True: + pre_num_procs = len(procs) # join procs procs = join_procs(procs) + + if return_data and socket_out and pre_num_procs > len(procs): + CONSOLE.print("[cyan]Returning Results[/]") + data = get_dominant_and_conf(shared_resources.data[-1]) + CONSOLE.print(f"[cyan]Sending Frequency:{data[0]}[/]") + CONSOLE.print(f"[cyan]Sending Confidence:{data[1]}[/]") + packet = struct.pack("dd", data[0], data[1]) + socket_out.send(packet) + # get messages - msgs, ranks = receive_messages(socket, poller) + msgs, ranks = receive_messages(socket_in, poller) if not msgs: CONSOLE.print("[red]No messages[/]") @@ -69,21 +87,26 @@ def predictor_with_processes_zmq( procs.append( handle_in_process( - prediction_process, args=(shard_resources, args, msgs) + prediction_process, args=(shared_resources, args, msgs) ) ) except KeyboardInterrupt: - print_data(shard_resources.data) - export_extrap(shard_resources.data) + print_data(shared_resources.data) + export_extrap(shared_resources.data) print("-- done -- ") -def bind_socket(addr: str, port: str): +def setup_socket(addr: str, port: str, socket_type=zmq.PULL, bind: bool = True): """Bind the ZMQ socket, retrying with a corrected IP if necessary.""" context = zmq.Context() - socket = context.socket(zmq.PULL) + socket = context.socket(socket_type) + if not bind and addr == "*": + addr = "127.0.0.1" try: - socket.bind(f"tcp://{addr}:{port}") + if bind: + socket.bind(f"tcp://{addr}:{port}") + else: + socket.connect(f"tcp://{addr}:{port}") except zmq.error.ZMQError as e: CONSOLE.print(f"[yellow]Error encountered:\n{e}[/]") CONSOLE.print("[yellow]Wrong IP address. Attempting to correct...[/]") @@ -107,8 +130,10 @@ def bind_socket(addr: str, port: str): addr = output.splitlines()[0].split("/")[0] CONSOLE.print("[bold green]Corrected IP address:[/]", addr) - - socket.bind(f"tcp://{addr}:{port}") + if bind: + socket.bind(f"tcp://{addr}:{port}") + else: + socket.connect(f"tcp://{addr}:{port}") CONSOLE.print(f"[green]FTIO is running on: {addr}:{port}[/]")