97 changes: 97 additions & 0 deletions examples/federated-analytics/README.md
@@ -0,0 +1,97 @@
---
tags: [basic, tabular, federated analytics]
dataset: [artificial]
framework: [pandas]
---

# Federated Analytics with OMOP CDM using Flower

This example shows how to use Flower to run federated analytics workloads on distributed SQL databases, a setup that applies to many biostatistics and healthcare use cases. You will use an artificial dataset generated in adherence to the [OMOP Common Data Model](https://www.ohdsi.org/data-standardization/), which builds on the OHDSI standardized vocabularies widely adopted in clinical domains. You will run this example with Flower's Deployment Engine to demonstrate how each SuperNode can be configured to connect to a different PostgreSQL database. The database connections are handled with SQLAlchemy and `psycopg`, the latest implementation of the PostgreSQL adapter for Python.
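The `task.py` module is not included in this diff; as a reference, here is a minimal sketch of what its `query_database` helper might look like, assuming it uses SQLAlchemy with the `psycopg` driver as described above (the signature is inferred from the call in `client_app.py`):

```python
import pandas as pd
from sqlalchemy import create_engine


def query_database(
    db_host: str,
    db_port: int,
    db_name: str,
    db_user: str,
    db_password: str,
    table_name: str,
    selected_features: list[str],
) -> pd.DataFrame:
    # The `postgresql+psycopg://` scheme selects the psycopg 3 driver
    url = f"postgresql+psycopg://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
    engine = create_engine(url)
    # NOTE: a real implementation should validate identifiers before
    # interpolating them into SQL
    columns = ", ".join(selected_features)
    return pd.read_sql(f"SELECT {columns} FROM {table_name}", engine)
```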

## Set up the project

### Clone the project

Cloning the project creates a new directory called `federated-analytics` containing the following files:

```shell
federated-analytics
├── db_init.sh # Defines an artificial OMOP CDM table
├── db_start.sh # Generates and starts PostgreSQL containers with OMOP CDM data
├── federated_analytics
│ ├── client_app.py # Defines your ClientApp
│ ├── server_app.py # Defines your ServerApp
│ └── task.py # Defines your database connection and data loading
├── pyproject.toml # Project metadata like dependencies and configs
└── README.md
```

### Install dependencies and project

Install the dependencies defined in `pyproject.toml` as well as the `federated-analytics` package.

```shell
# From a new Python environment, run:
pip install -e .
```

### Start PostgreSQL databases

Run the following to start two PostgreSQL databases and initialize the dataset for each:

```shell
./db_start.sh
```

> [!NOTE]
> To start more than two databases, pass the desired number as the first argument to the script, e.g. `./db_start.sh 3`.
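Optionally, you can check that a database initialized correctly. The following is a hypothetical sanity check that queries the first database directly with `psycopg`; the connection values are the defaults from `db_start.sh`:

```python
import psycopg

# Connect to the first database started by db_start.sh (host port 5433)
with psycopg.connect(
    "host=localhost port=5433 dbname=flwrlabs user=flwrlabs password=flwrlabs"
) as conn:
    (count,) = conn.execute("SELECT count(*) FROM person_measurements").fetchone()
    print(f"person_measurements rows: {count}")  # expect 100
```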

### Run with the Deployment Engine

For a basic execution of this federated analytics app, activate your environment and start the SuperLink process in insecure mode:

```shell
flower-superlink --insecure
```

Next, start two SuperNodes and connect them to the SuperLink. Specify a different `--node-config` for each so that every SuperNode connects to its own PostgreSQL database.

```shell
flower-supernode \
--insecure \
--superlink 127.0.0.1:9092 \
--clientappio-api-address 127.0.0.1:9094 \
--node-config="db-port=5433"
```

```shell
flower-supernode \
--insecure \
--superlink 127.0.0.1:9092 \
--clientappio-api-address 127.0.0.1:9095 \
--node-config="db-port=5434"
```

Next, update the [`pyproject.toml`](./pyproject.toml) file to add a new federation configuration:

```toml
[tool.flwr.federations.local-deployment]
address = "127.0.0.1:9093"
insecure = true
```

Finally, run the Flower App and follow the `ServerApp` logs to track the execution of the run:

```shell
flwr run . local-deployment --stream
```
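While the run executes, each `ClientApp` reports only partial statistics (per-feature sums, counts, and sums of squares), so raw rows never leave the database hosts. `task.py` is not part of this diff, but here is a minimal sketch of how its `aggregate_features` helper could pool these partials, assuming the metric names emitted by `client_app.py` and that every reply carries valid results:

```python
import math


def aggregate_features(replies, selected_features, feature_aggregation):
    """Pool partial statistics reported by the ClientApps."""
    aggregated = {}
    for feature in selected_features:
        for agg in feature_aggregation:
            results = [r.content["query_results"] for r in replies]
            total = sum(res[f"{feature}_{agg}_sum"] for res in results)
            count = sum(res[f"{feature}_{agg}_count"] for res in results)
            mean = total / count
            if agg == "mean":
                aggregated[f"{feature}_mean"] = mean
            elif agg == "std":
                sum_sqd = sum(res[f"{feature}_{agg}_sum_sqd"] for res in results)
                # Var(X) = E[X^2] - E[X]^2
                aggregated[f"{feature}_std"] = math.sqrt(sum_sqd / count - mean**2)
    return aggregated
```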

You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

```shell
flwr run . local-deployment --run-config "selected-features='age,bmi'" --stream
```

The steps above are adapted from this [how-to guide](https://flower.ai/docs/framework/how-to-run-flower-with-deployment-engine.html). After that, you might be interested in setting up [secure TLS-enabled communications](https://flower.ai/docs/framework/how-to-enable-tls-connections.html) and [SuperNode authentication](https://flower.ai/docs/framework/how-to-authenticate-supernodes.html) in your federation.

If you are already familiar with how the Deployment Engine works, you may want to learn how to run it using Docker. Check out the [Flower with Docker](https://flower.ai/docs/framework/docker/index.html) documentation.
28 changes: 28 additions & 0 deletions examples/federated-analytics/db_init.sh
@@ -0,0 +1,28 @@
#!/usr/bin/env bash

# Seed for PostgreSQL's random(); db_start.sh sets DB_SEED per container
SEED=${DB_SEED:-0.42}

psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
CREATE TABLE person_measurements (
person_id INTEGER PRIMARY KEY,
age INTEGER,
bmi FLOAT,
systolic_bp INTEGER,
diastolic_bp INTEGER,
ldl_cholesterol FLOAT,
hba1c FLOAT
);

SELECT setseed($SEED);

INSERT INTO person_measurements
SELECT
gs AS person_id,
20 + (random() * 60)::INT AS age,
18 + (random() * 15) AS bmi,
100 + (random() * 40)::INT AS systolic_bp,
60 + (random() * 25)::INT AS diastolic_bp,
70 + (random() * 120) AS ldl_cholesterol,
4.5 + (random() * 4) AS hba1c
FROM generate_series(1, 100) gs;
EOSQL
31 changes: 31 additions & 0 deletions examples/federated-analytics/db_start.sh
@@ -0,0 +1,31 @@
#!/usr/bin/env bash

set -e

N=${1:-2} # number of PostgreSQL databases (default = 2)
BASE_PORT=5433

{
echo "services:"

for i in $(seq 1 "$N"); do
PORT=$((BASE_PORT + i - 1))
    # Give each database a different seed so it generates different random data
SEED=$(echo "scale=2; $i / 100" | bc)
cat <<EOF
postgres_$i:
image: postgres:18
container_name: postgres_$i
environment:
POSTGRES_USER: flwrlabs
POSTGRES_PASSWORD: flwrlabs
POSTGRES_DB: flwrlabs
DB_SEED: $SEED
ports:
- "$PORT:5432"
volumes:
- ./db_init.sh:/docker-entrypoint-initdb.d/init.sh:ro

EOF
done
} | docker compose -f - up -d
61 changes: 61 additions & 0 deletions examples/federated-analytics/federated_analytics/client_app.py
@@ -0,0 +1,61 @@
"""federated_analytics: A Flower / Federated Analytics app."""

import warnings

from flwr.app import Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp

from federated_analytics.task import query_database

warnings.filterwarnings("ignore", category=UserWarning)

# Flower ClientApp
app = ClientApp()


@app.query()
def query(msg: Message, context: Context) -> Message:
"""Query PostgreSQL database and report aggregated results to `ServerApp`."""

    # Get database connection details from the node config; the defaults match
    # the credentials created by db_start.sh, so only `db-port` needs to be
    # set per SuperNode
db_host: str = context.node_config.get("db-host", "localhost")
db_port: int = context.node_config.get("db-port", 5432)
db_name: str = context.node_config.get("db-name", "flwrlabs")
db_user: str = context.node_config.get("db-user", "flwrlabs")
db_password: str = context.node_config.get("db-password", "flwrlabs")
table_name: str = context.node_config.get("table-name", "person_measurements")

selected_features: list[str] = msg.content["config"]["selected_features"]
feature_aggregation: list[str] = msg.content["config"]["feature_aggregation"]

# Query database
df = query_database(
db_host=db_host,
db_port=db_port,
db_name=db_name,
db_user=db_user,
db_password=db_password,
table_name=table_name,
selected_features=selected_features,
)

# Compute aggregation metrics
metrics = {}
for feature in selected_features:
if feature not in df.columns:
raise ValueError(f"Feature '{feature}' not found in dataset columns.")

        for agg in feature_aggregation:
            if agg == "mean":
                # Partial sum and count let the ServerApp compute a pooled mean
                metrics[f"{feature}_{agg}_sum"] = sum(df[feature])
                metrics[f"{feature}_{agg}_count"] = len(df[feature])
            elif agg == "std":
                # Sum, count, and sum of squares are sufficient statistics
                # for the pooled standard deviation
                metrics[f"{feature}_{agg}_sum"] = sum(df[feature])
                metrics[f"{feature}_{agg}_count"] = len(df[feature])
                metrics[f"{feature}_{agg}_sum_sqd"] = sum(df[feature] ** 2)
            else:
                print(f"Aggregation method '{agg}' not recognized.")

reply_content = RecordDict({"query_results": MetricRecord(metrics)})

return Message(reply_content, reply_to=msg)
85 changes: 85 additions & 0 deletions examples/federated-analytics/federated_analytics/server_app.py
@@ -0,0 +1,85 @@
"""federated_analytics: A Flower / Federated Analytics app."""

import json
import random
import time
from logging import INFO

from flwr.app import ConfigRecord, Context, Message, MessageType, RecordDict
from flwr.common.logger import log
from flwr.serverapp import Grid, ServerApp

from federated_analytics.task import aggregate_features

app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
"""This `ServerApp` construct a histogram from partial-histograms reported by the
`ClientApp`s."""

min_nodes = 1
fraction_sample = context.run_config["fraction-sample"]
selected_features = str(context.run_config["selected-features"]).split(",")
feature_aggregation = str(context.run_config["feature-aggregation"]).split(",")

log(INFO, "") # Add newline for log readability

log(INFO, "=" * 60)
log(INFO, "FEDERATED ANALYTICS CONFIGURATION".center(60))
log(INFO, "=" * 60)
log(INFO, "Selected features:")
for i, feature in enumerate(selected_features, 1):
log(INFO, " %d. %s", i, feature.strip())
log(INFO, "Feature aggregation methods: %s", ", ".join(feature_aggregation))
log(INFO, "=" * 60)

log(INFO, "") # Add newline for log readability

# Loop and wait until enough nodes are available.
all_node_ids: list[int] = []
while len(all_node_ids) < min_nodes:
all_node_ids = list(grid.get_node_ids())
if len(all_node_ids) >= min_nodes:
# Sample nodes
num_to_sample = int(len(all_node_ids) * fraction_sample)
node_ids = random.sample(all_node_ids, num_to_sample)
break
log(INFO, "Waiting for nodes to connect...")
time.sleep(2)

log(INFO, "Sampled %s nodes (out of %s)", len(node_ids), len(all_node_ids))

# Create messages
config = ConfigRecord(
{
"selected_features": selected_features,
"feature_aggregation": feature_aggregation,
}
)
recorddict = RecordDict({"config": config})
messages = []
for node_id in node_ids: # one message for each node
message = Message(
content=recorddict,
message_type=MessageType.QUERY, # target `query` method in ClientApp
dst_node_id=node_id,
group_id="1",
)
messages.append(message)

# Send messages and wait for all results
    # Materialize the iterator so the replies can be counted and then aggregated
    replies = list(grid.send_and_receive(messages))
    log(INFO, "Received %s/%s results", len(replies), len(messages))

aggregated_stats = aggregate_features(
replies, selected_features, feature_aggregation
)

# Display final aggregated stats
print("\n" + "=" * 40)
print("FINAL AGGREGATED STATISTICS".center(40))
print("=" * 40)
print(json.dumps(aggregated_stats, indent=2))
print("=" * 40 + "\n")