69 changes: 69 additions & 0 deletions doreisa/in_transit_analytic_actor.py
@@ -0,0 +1,69 @@
from collections.abc import Callable
from dataclasses import dataclass

import cloudpickle
import numpy as np
import ray
import ray.actor
import zmq.asyncio


@dataclass
class SendChunkRequest:
"""
Represents a request to send a chunk of data to the analytic node.
"""

array_name: str
chunk_position: tuple[int, ...]
nb_chunks_per_dim: tuple[int, ...]
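    # Number of chunks the receiving analytic node will get for this array (see InTransitClient.add_chunk)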
nb_chunks_of_analysis_node: int
timestep: int
chunk: np.ndarray


@ray.remote
class InTransitAnalyticActor:
"""
Actor that runs a ZMQ server to receive data from the simulation nodes.
"""

def __init__(self, _fake_node_id: str | None = None) -> None:
self.node_id = _fake_node_id or ray.get_runtime_context().get_node_id()

# Get the head actor and the scheduling actor
self.head = ray.get_actor("simulation_head", namespace="doreisa")
self.scheduling_actor: ray.actor.ActorHandle = ray.get(
self.head.scheduling_actor.remote(self.node_id, is_fake_id=bool(_fake_node_id))
)

self.preprocessing_callbacks: dict[str, Callable] = ray.get(self.head.preprocessing_callbacks.remote())

self.context = zmq.asyncio.Context()

async def run_zmq_server(self, address: str):
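        # zmq.REP answers exactly one reply per request, in lockstep with the client's REQ socket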
socket = self.context.socket(zmq.REP)
socket.bind(f"tcp://{address}")

while True:
message = await socket.recv_pyobj()

if message == "get_preprocessing_callbacks":
# Send the preprocessing callbacks to the client
# Cloudpickle is needed since pickle fails to serialize the callbacks
await socket.send_pyobj(cloudpickle.dumps(self.preprocessing_callbacks))
continue

assert isinstance(message, SendChunkRequest)

await self.scheduling_actor.add_chunk.remote(
array_name=message.array_name,
timestep=message.timestep,
chunk_position=message.chunk_position,
dtype=message.chunk.dtype,
nb_chunks_per_dim=message.nb_chunks_per_dim,
nb_chunks_of_node=message.nb_chunks_of_analysis_node,
chunk=[ray.put(message.chunk)],
chunk_shape=message.chunk.shape,
)
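        # Acknowledge, so the client (which blocks on the reply) can return to the simulation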
await socket.send_pyobj(None)
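For context, a minimal sketch (not part of this diff) of how an analytic node could create this actor and start its ZMQ server. The bind address, the `ray.init` call, and the single-endpoint setup are illustrative assumptions; the Doreisa head node (the `simulation_head` actor) is assumed to be running already:

```python
import ray

from doreisa.in_transit_analytic_actor import InTransitAnalyticActor

ray.init(address="auto")  # join the existing Ray cluster

# One actor per analytic node; run_zmq_server loops forever, so its
# ObjectRef is deliberately not waited on. The address is illustrative.
actor = InTransitAnalyticActor.remote()
actor.run_zmq_server.remote("0.0.0.0:8000")
```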
48 changes: 46 additions & 2 deletions doreisa/simulation_node.py
@@ -1,9 +1,13 @@
import logging
from typing import Callable

import cloudpickle
import numpy as np
import ray
import ray.actor
import zmq

from doreisa.in_transit_analytic_actor import SendChunkRequest


class Client:
@@ -40,7 +44,6 @@ def add_chunk(
nb_chunks_of_node: int,
timestep: int,
chunk: np.ndarray,
-        store_externally: bool = False,
) -> None:
"""
Make a chunk of data available to the analytic.
@@ -52,7 +55,6 @@ def add_chunk(
nb_chunks_of_node: The number of chunks sent by this node. The scheduling actor will
inform the head actor when all the chunks are ready.
chunk: The chunk of data.
-            store_externally: If True, the data is stored externally. TODO Not implemented yet.
"""
chunk = self.preprocessing_callbacks[array_name](chunk)

@@ -72,3 +74,45 @@

# Wait until the data is processed before returning to the simulation
ray.get(future)


class InTransitClient:
def __init__(self, analytic_node_address: str):
# Open a socket to communicate with the analytic node
context = zmq.Context()
self.socket = context.socket(zmq.REQ)
self.socket.connect(f"tcp://{analytic_node_address}")

# Get the preprocessing callbacks from the analytic node
self.socket.send_pyobj("get_preprocessing_callbacks")
self.preprocessing_callbacks: dict[str, Callable] = cloudpickle.loads(self.socket.recv_pyobj())

def add_chunk(
self,
array_name: str,
chunk_position: tuple[int, ...],
nb_chunks_per_dim: tuple[int, ...],
nb_chunks_of_analysis_node: int,
timestep: int,
chunk: np.ndarray,
) -> None:
"""
nb_chunks_of_analysis_node: The number of chunks that the ANALYTIC node will
receive for this array.
"""
chunk = self.preprocessing_callbacks[array_name](chunk)

# Send the data to the analytic node
self.socket.send_pyobj(
SendChunkRequest(
array_name=array_name,
chunk_position=chunk_position,
nb_chunks_per_dim=nb_chunks_per_dim,
nb_chunks_of_analysis_node=nb_chunks_of_analysis_node,
timestep=timestep,
chunk=chunk,
)
)

# Wait for the response from the analytic node
self.socket.recv_pyobj()
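A minimal usage sketch of `InTransitClient` from one simulation rank; the grid layout and values are illustrative (the `in_transit_worker` helper in `tests/utils.py` below does essentially the same thing):

```python
import numpy as np

from doreisa.simulation_node import InTransitClient

client = InTransitClient(analytic_node_address="localhost:8000")

# Illustrative layout: a 2x2 chunk grid where this rank owns chunk (0, 0)
# and the analytic node behind localhost:8000 receives all 4 chunks.
client.add_chunk(
    array_name="array",
    chunk_position=(0, 0),
    nb_chunks_per_dim=(2, 2),
    nb_chunks_of_analysis_node=4,
    timestep=0,
    chunk=np.zeros((8, 8), dtype=np.int32),
)
```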
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "doreisa"
version = "0.3.3"
version = "0.4.0"
description = ""
authors = [{ name = "Adrien Vannson", email = "[email protected]" }]
requires-python = ">=3.12"
@@ -9,6 +9,7 @@ dependencies = [
"dask[dataframe] (==2024.6.0)",
"ray[default] (>=2.46.0,<3.0.0)",
"numpy (==1.26.4)", # TODO this was pinned for PDI, remove the pinning?
"pyzmq>=27.0.0", # Only needed for in transfer analytics. TODO move to group?
]

[build-system]
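One way the `TODO move to group?` above could be resolved later (a sketch, not part of this PR): declaring pyzmq as a PEP 621 optional dependency, so only in-transit deployments install it, e.g. via `pip install doreisa[in-transit]`:

```toml
[project.optional-dependencies]
in-transit = ["pyzmq>=27.0.0"]
```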
70 changes: 70 additions & 0 deletions tests/test_in_transit.py
@@ -0,0 +1,70 @@
import dask.array as da
import pytest
import ray
import ray.actor

from tests.utils import in_transit_worker, ray_cluster, wait_for_head_node # noqa: F401

NB_ITERATIONS = 10


@ray.remote(max_retries=0)
def head_script() -> None:
"""The head node checks that the values are correct"""
from doreisa.head_node import init
from doreisa.window_api import ArrayDefinition, run_simulation

init()

def simulation_callback(array: da.Array, timestep: int):
x = array.sum().compute()
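        # Four 1x1 chunks each hold (rank + 1) * timestep, so the sum is (1 + 2 + 3 + 4) * timestep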

assert x == 10 * timestep

run_simulation(
simulation_callback,
[ArrayDefinition("array")],
max_iterations=NB_ITERATIONS,
)


@pytest.mark.parametrize(
"nb_simulation_nodes, nb_analytic_nodes",
[(1, 1), (2, 2), (2, 4), (4, 2), (4, 4)],
)
def test_in_transit(nb_simulation_nodes: int, nb_analytic_nodes: int, ray_cluster) -> None: # noqa: F811
from doreisa.in_transit_analytic_actor import InTransitAnalyticActor

# Start the head actor
head_ref = head_script.remote()
wait_for_head_node()

# Start the analytic actors
analytic_actors: list[ray.actor.ActorHandle] = []
for i in range(nb_analytic_nodes):
actor = InTransitAnalyticActor.remote(_fake_node_id=f"node_{i}")

for j in range(4 // nb_analytic_nodes):
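            # Actor i binds k = 4 // nb_analytic_nodes consecutive ports starting at 8000 + i * k,
            # so worker `rank`, which connects to port 8000 + rank, reaches actor rank // k.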
actor.run_zmq_server.remote(f"localhost:{8000 + i * (4 // nb_analytic_nodes) + j}")

analytic_actors.append(actor)

worker_refs = []
for rank in range(4):
worker_refs.append(
in_transit_worker.remote(
rank=rank,
position=(rank // 2, rank % 2),
chunks_per_dim=(2, 2),
nb_chunks_of_analytic_node=4 // nb_analytic_nodes,
chunk_size=(1, 1),
nb_iterations=NB_ITERATIONS,
analytic_node_address=f"localhost:{8000 + rank}",
)
)

ray.get([head_ref] + worker_refs)

    # Check that one scheduling actor was created per (fake) analytic node
simulation_head = ray.get_actor("simulation_head", namespace="doreisa")
assert len(ray.get(simulation_head.list_scheduling_actors.remote())) == nb_analytic_nodes
26 changes: 25 additions & 1 deletion tests/utils.py
@@ -44,4 +44,28 @@ def simple_worker(
array = (rank + 1) * np.ones(chunk_size, dtype=dtype)

for i in range(nb_iterations):
-        client.add_chunk(array_name, position, chunks_per_dim, nb_chunks_of_node, i, i * array, store_externally=False)
+        client.add_chunk(array_name, position, chunks_per_dim, nb_chunks_of_node, i, i * array)


@ray.remote(num_cpus=0, max_retries=0)
def in_transit_worker(
*,
rank: int,
position: tuple[int, ...],
chunks_per_dim: tuple[int, ...],
nb_chunks_of_analytic_node: int,
chunk_size: tuple[int, ...],
nb_iterations: int,
array_name: str = "array",
dtype: np.dtype = np.int32, # type: ignore
analytic_node_address: str,
) -> None:
"""Worker node sending chunks of data"""
from doreisa.simulation_node import InTransitClient

client = InTransitClient(analytic_node_address=analytic_node_address)

array = (rank + 1) * np.ones(chunk_size, dtype=dtype)

for i in range(nb_iterations):
client.add_chunk(array_name, position, chunks_per_dim, nb_chunks_of_analytic_node, i, i * array)