Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions framework/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
from flwr.supercore.address import parse_address, resolve_bind_address
from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
from flwr.supercore.ffs import FfsFactory
from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
from flwr.supercore.object_store import ObjectStoreFactory
from flwr.supercore.update_check import warn_if_flwr_update_available
Expand Down Expand Up @@ -307,15 +306,18 @@ def run_superlink() -> None:
)
state_factory.state() # Force initialization before starting servers

# Initialize FfsFactory
ffs_factory = FfsFactory(args.storage_dir)
if "--storage-dir" in explicit_args:
log(
WARN,
"The `--storage-dir` argument is deprecated and has no effect in "
"SuperLink. FAB artifacts are stored in LinkState.",
)

# Start Control API
is_simulation = args.simulation
control_server: grpc.Server = run_control_api_grpc(
address=control_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
objectstore_factory=objectstore_factory,
certificates=certificates,
authn_plugin=authn_plugin,
Expand All @@ -331,7 +333,6 @@ def run_superlink() -> None:
serverappio_server: grpc.Server = run_serverappio_api_grpc(
address=serverappio_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
objectstore_factory=objectstore_factory,
certificates=None, # ServerAppIo API doesn't support SSL yet
)
Expand Down Expand Up @@ -378,7 +379,6 @@ def run_superlink() -> None:
args.ssl_keyfile,
args.ssl_certfile,
state_factory,
ffs_factory,
objectstore_factory,
num_workers,
),
Expand All @@ -398,7 +398,6 @@ def run_superlink() -> None:
fleet_server = _run_fleet_api_grpc_rere(
address=fleet_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
objectstore_factory=objectstore_factory,
enable_supernode_auth=enable_supernode_auth,
certificates=certificates,
Expand All @@ -409,7 +408,6 @@ def run_superlink() -> None:
fleet_server = _run_fleet_api_grpc_adapter(
address=fleet_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
objectstore_factory=objectstore_factory,
certificates=certificates,
)
Expand Down Expand Up @@ -553,7 +551,6 @@ def _try_obtain_fleet_event_log_writer_plugin() -> EventLogWriterPlugin | None:
def _run_fleet_api_grpc_rere( # pylint: disable=R0913, R0917
address: str,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
objectstore_factory: ObjectStoreFactory,
enable_supernode_auth: bool,
certificates: tuple[bytes, bytes, bytes] | None,
Expand All @@ -563,7 +560,6 @@ def _run_fleet_api_grpc_rere( # pylint: disable=R0913, R0917
# Create Fleet API gRPC server
fleet_servicer = FleetServicer(
state_factory=state_factory,
ffs_factory=ffs_factory,
objectstore_factory=objectstore_factory,
enable_supernode_auth=enable_supernode_auth,
)
Expand All @@ -590,15 +586,13 @@ def _run_fleet_api_grpc_rere( # pylint: disable=R0913, R0917
def _run_fleet_api_grpc_adapter(
address: str,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
objectstore_factory: ObjectStoreFactory,
certificates: tuple[bytes, bytes, bytes] | None,
) -> grpc.Server:
"""Run Fleet API (GrpcAdapter)."""
# Create Fleet API gRPC server
fleet_servicer = GrpcAdapterServicer(
state_factory=state_factory,
ffs_factory=ffs_factory,
objectstore_factory=objectstore_factory,
enable_supernode_auth=False,
)
Expand Down Expand Up @@ -628,7 +622,6 @@ def _run_fleet_api_rest(
ssl_keyfile: str | None,
ssl_certfile: str | None,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
objectstore_factory: ObjectStoreFactory,
num_workers: int,
) -> None:
Expand All @@ -644,7 +637,6 @@ def _run_fleet_api_rest(

# See: https://www.starlette.io/applications/#accessing-the-app-instance
fast_api_app.state.STATE_FACTORY = state_factory
fast_api_app.state.FFS_FACTORY = ffs_factory
fast_api_app.state.OBJECTSTORE_FACTORY = objectstore_factory

uvicorn.run(
Expand Down Expand Up @@ -731,7 +723,7 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
)
parser.add_argument(
"--storage-dir",
help="The base directory to store the objects for the Flower File System.",
help="Deprecated and ignored by SuperLink.",
default=BASE_DIR,
)
parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_rpc_completion() -> None:
"""Test if the GrpcAdapter servicer can handle all requests for Fleet API."""
# Prepare
all_method_names = (name for name in dir(FleetServicer) if name[0].isupper())
servicer = GrpcAdapterServicer(Mock(), Mock(), Mock(), Mock())
servicer = GrpcAdapterServicer(Mock(), Mock(), False)

# Execute
for method_name in all_method_names:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.linkstate import LinkStateFactory
from flwr.server.superlink.utils import abort_grpc_context
from flwr.supercore.ffs import FfsFactory
from flwr.supercore.inflatable.inflatable_object import UnexpectedObjectContentError
from flwr.supercore.object_store import ObjectStoreFactory

Expand All @@ -67,12 +66,10 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
def __init__(
self,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
objectstore_factory: ObjectStoreFactory,
enable_supernode_auth: bool,
) -> None:
self.state_factory = state_factory
self.ffs_factory = ffs_factory
self.objectstore_factory = objectstore_factory
self.enable_supernode_auth = enable_supernode_auth
self.lock = threading.Lock()
Expand Down Expand Up @@ -262,7 +259,6 @@ def GetFab(
try:
res = message_handler.get_fab(
request=request,
ffs=self.ffs_factory.ffs(),
state=self.state_factory.state(),
store=self.objectstore_factory.store(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Flower FleetServicer tests."""


import tempfile
import unittest
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -72,7 +71,6 @@
NodeStatus,
RunType,
)
from flwr.supercore.ffs import FfsFactory
from flwr.supercore.inflatable.inflatable_object import (
get_all_nested_objects,
get_object_id,
Expand All @@ -90,17 +88,11 @@ class TestFleetServicer(unittest.TestCase): # pylint: disable=R0902, R0904

def setUp(self) -> None:
"""Initialize mock stub and server interceptor."""
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=R1732
self.addCleanup(self.temp_dir.cleanup) # Ensures cleanup after test

objectstore_factory = ObjectStoreFactory()
state_factory = LinkStateFactory(
FLWR_IN_MEMORY_DB_NAME, NoOpFederationManager(), objectstore_factory
)
self.state = state_factory.state()
ffs_factory = FfsFactory(self.temp_dir.name)
self.ffs = ffs_factory.ffs()
self.store = objectstore_factory.store()
self.node_pk = b"fake public key"

Expand All @@ -109,7 +101,6 @@ def setUp(self) -> None:
self._server: grpc.Server = _run_fleet_api_grpc_rere(
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
state_factory,
ffs_factory,
objectstore_factory,
self.enable_node_auth,
None,
Expand Down Expand Up @@ -175,7 +166,7 @@ def setUp(self) -> None:

def tearDown(self) -> None:
"""Clean up grpc server."""
self._server.stop(None)
self._server.stop(None).wait(timeout=2)

def _create_dummy_node(self, activate: bool = True) -> int:
"""Create a dummy node."""
Expand Down Expand Up @@ -509,7 +500,7 @@ def test_successful_get_fab_if_running(self) -> None:
# Prepare
node_id = self._create_dummy_node()
fab_content = b"content"
fab_hash = self.ffs.put(fab_content, {"meta": "data"})
fab_hash = self.state.put_fab(fab_content, {"meta": "data"})
run_id = self._create_dummy_run(fab_hash=fab_hash)

# Transition status to running. GetFab RPC is only allowed in running status.
Expand Down Expand Up @@ -550,7 +541,7 @@ def test_get_fab_not_successful_if_not_running(self, num_transitions: int) -> No
# Prepare
node_id = self._create_dummy_node()
fab_content = b"content"
fab_hash = self.ffs.put(fab_content, {"meta": "data"})
fab_hash = self.state.put_fab(fab_content, {"meta": "data"})
run_id = self._create_dummy_run(running=False, fab_hash=fab_hash)

self._transition_run_status(run_id, num_transitions)
Expand All @@ -563,7 +554,7 @@ def test_get_fab_permission_denied_if_node_not_in_federation(self) -> None:
# Prepare
node_id = self._create_dummy_node()
fab_content = b"content"
fab_hash = self.ffs.put(fab_content, {"meta": "data"})
fab_hash = self.state.put_fab(fab_content, {"meta": "data"})
run_id = self._create_dummy_run(fab_hash=fab_hash)

# Mock federation manager to exclude the node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


import datetime
import tempfile
import unittest
from collections.abc import Callable
from typing import Any
Expand Down Expand Up @@ -70,7 +69,6 @@
from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory
from flwr.server.superlink.linkstate.linkstate_test import create_res_message
from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME, NOOP_FEDERATION, RunType
from flwr.supercore.ffs import FfsFactory
from flwr.supercore.object_store import ObjectStoreFactory
from flwr.supercore.primitives.asymmetric import (
generate_key_pairs,
Expand All @@ -97,16 +95,12 @@ def setUp(self) -> None:
FLWR_IN_MEMORY_DB_NAME, NoOpFederationManager(), objectstore_factory
)
self.state = state_factory.state()
self.tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=R1732
ffs_factory = FfsFactory(self.tmp_dir.name)
self.ffs = ffs_factory.ffs()
self.store = objectstore_factory.store()

self._server_interceptor = NodeAuthServerInterceptor(state_factory)
self._server: grpc.Server = _run_fleet_api_grpc_rere(
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
state_factory,
ffs_factory,
objectstore_factory,
self.enable_node_auth,
None,
Expand Down Expand Up @@ -172,9 +166,7 @@ def setUp(self) -> None:

def tearDown(self) -> None:
"""Clean up grpc server."""
self._server.stop(None)
# Cleanup the temp directory
self.tmp_dir.cleanup()
self._server.stop(None).wait(timeout=2)

def _make_metadata(self) -> list[Any]:
"""Create metadata with signature and timestamp."""
Expand Down Expand Up @@ -316,7 +308,7 @@ def _test_send_node_heartbeat(self, metadata: list[Any]) -> Any:

def _test_get_fab(self, metadata: list[Any]) -> Any:
"""Test GetFab."""
fab_hash = self.ffs.put(b"mock fab content", {})
fab_hash = self.state.put_fab(b"mock fab content", {})
node_id = self._create_node_in_linkstate()
run_id = self._create_dummy_run()
req = GetFabRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.server.superlink.linkstate import LinkState
from flwr.server.superlink.utils import check_abort
from flwr.supercore.ffs import Ffs
from flwr.supercore.inflatable.inflatable_object import UnexpectedObjectContentError
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore

Expand Down Expand Up @@ -250,7 +249,7 @@ def get_run(


def get_fab(
request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
request: GetFabRequest, state: LinkState, store: ObjectStore
) -> GetFabResponse:
"""Get FAB."""
# Validate that the requesting SuperNode is part of the federation
Expand All @@ -266,7 +265,7 @@ def get_fab(
if abort_msg:
raise InvalidRunStatusException(abort_msg)

if result := ffs.get(request.hash_str):
if result := state.get_fab(request.hash_str):
fab = Fab(request.hash_str, result[0], result[1])
return GetFabResponse(fab=fab_to_proto(fab))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
from flwr.supercore.ffs import Ffs, FfsFactory
from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory

try:
Expand Down Expand Up @@ -224,15 +223,12 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
@rest_request_response(GetFabRequest)
async def get_fab(request: GetFabRequest) -> GetFabResponse:
"""GetRun."""
# Get ffs from app
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()

# Get state from app
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()

# Handle message
return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
return message_handler.get_fab(request=request, state=state, store=store)


@rest_request_response(ConfirmMessageReceivedRequest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""In-memory LinkState implementation."""


import hashlib
import threading
from bisect import bisect_right
from collections import defaultdict
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)

self.node_public_keys: set[bytes] = set()
self.fab_artifacts: dict[str, tuple[bytes, dict[str, str]]] = {}

self.lock = threading.RLock()
federation_manager.linkstate = self
Expand All @@ -101,6 +103,34 @@ def federation_manager(self) -> FederationManager:
"""Get the FederationManager instance."""
return self._federation_manager

def put_fab(self, content: bytes, verifications: dict[str, str]) -> str:
"""Store FAB content and verifications and return FAB hash."""
fab_hash = hashlib.sha256(content).hexdigest()
with self.lock:
self.fab_artifacts.setdefault(
fab_hash,
(content, dict(verifications)),
)
return fab_hash

def get_fab(self, fab_hash: str) -> tuple[bytes, dict[str, str]] | None:
"""Retrieve FAB content and verifications by hash."""
with self.lock:
entry = self.fab_artifacts.get(fab_hash)
if entry is None:
return None
content, verifications = entry
if hashlib.sha256(content).hexdigest() != fab_hash:
log(ERROR, "Corrupt FAB artifact in LinkState for hash %s", fab_hash)
return None
if not all(
isinstance(key, str) and isinstance(value, str)
for key, value in verifications.items()
):
log(ERROR, "Invalid FAB verification metadata for hash %s", fab_hash)
return None
return content, dict(verifications)

def store_message_ins(self, message: Message) -> str | None:
"""Store one Message."""
# Validate message
Expand Down
Loading
Loading