diff --git a/examples/quickstart-diffusion/README.md b/examples/quickstart-diffusion/README.md
new file mode 100644
index 000000000000..151d41cf4c00
--- /dev/null
+++ b/examples/quickstart-diffusion/README.md
@@ -0,0 +1,191 @@
+---
+tags: [quickstart, diffusion, vision, federated-learning]
+dataset: [Oxford-Flowers]
+framework: [diffusers, peft, flower]
+---
+
+# Federated Diffusion Model Training with Flower (Quickstart Example)
+
+This example demonstrates how to train a **Diffusion Model** (based on [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)) in a **Federated Learning (FL)** environment using [Flower](https://flower.ai/).
+The training uses **Low-Rank Adaptation (LoRA)** to enable lightweight fine-tuning of a diffusion model in a distributed setup, even with limited compute resources.
+The example uses the Oxford Flowers dataset, a collection of RGB flower images commonly used for training image-generation models.
+
+**Prompt:** _A realistic image of a blooming yellow daffodil in natural sunlight_
+
+The model was fine-tuned for 5 federated rounds, with each client using 400 images for training and 80 images for evaluation.
+Below is a sample output image generated with the final LoRA-adapted model.
+
+![](_static/federated_diffusion_sample.png)
+
+---
+
+## Overview
+
+In this example:
+
+- Each client fine-tunes **only the LoRA parameters** of the Stable Diffusion UNet.
+- The **Oxford-Flowers** dataset is partitioned among multiple clients using Flower Datasets.
+- The server performs **FedAvg aggregation** on the LoRA weights after each round.
+- After all rounds, the aggregated LoRA adapter is saved into `final_lora_model/` and can be merged with the base model for image generation.
+
+This provides a clean example of how **federated diffusion fine-tuning** can be performed using Diffusers + PEFT + Flower.
+
+---
+
+## Set up the project
+
+### Clone the project
+
+Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will check out the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git _tmp \
+        && mv _tmp/examples/quickstart-diffusion . \
+        && rm -rf _tmp && cd quickstart-diffusion
+```
+
+After cloning, your directory will look like this:
+
+```shell
+quickstart-diffusion
+├── diffusion_example
+│   ├── __init__.py
+│   ├── client_app.py   # Defines your ClientApp logic
+│   ├── server_app.py   # Defines the ServerApp and strategy
+│   └── task.py         # Model setup, data loading, and training functions
+├── pyproject.toml      # Project metadata and dependencies
+└── README.md           # This file
+```
+
+### Install dependencies and project
+
+Install the dependencies defined in `pyproject.toml` as well as the `diffusion_example` package.
+
+```bash
+pip install -e .
+```
+
+## Run the Example
+
+You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.
+
+### Run with the Simulation Engine
+
+> [!TIP]
+> This example runs faster when the `ClientApp`s have access to a GPU. If your system has one, you can make use of it by configuring the `backend.client-resources` component in `pyproject.toml`. If you want to try running the example with GPU right away, use the `local-simulation-gpu` federation as shown below.
+> Check the [Simulation Engine documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) to learn more about Flower simulations and how to optimize them.
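+
+For reference, this is how the `local-simulation-gpu` federation is configured at the bottom of this example's `pyproject.toml`:
+
+```toml
+[tool.flwr.federations.local-simulation-gpu]
+options.num-supernodes = 2
+options.backend.client-resources.num-cpus = 2
+options.backend.client-resources.num-gpus = 0.5
+```
+
+With `num-gpus = 0.5`, at most two `ClientApp`s run concurrently on a single GPU; raise it to `1.0` if you prefer to run one client at a time.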
+
+```bash
+# Run with the default federation (CPU only)
+flwr run .
+```
+
+Run the project in the `local-simulation-gpu` federation, which allocates both CPU and GPU resources to each `ClientApp`.
+Since diffusion models are memory-intensive, we recommend running **one client at a time** or limiting parallelism to **1–2 clients per GPU** (each client may use 3–6 GB of VRAM depending on image size and model configuration).
+You can modify the level of parallelism or memory allocation in the `client-resources` section of your `pyproject.toml` (shown above) to fit your system's GPU capacity.
+
+```bash
+# Run with the `local-simulation-gpu` federation
+flwr run . local-simulation-gpu
+```
+
+You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:
+
+```bash
+flwr run . --run-config "num-server-rounds=5 fraction-train=0.1"
+```
+
+### Result output
+
+Example training output for each client and the corresponding server logs:
+
+```bash
+Loading project configuration...
+Success
+Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 19.94it/s]
+Trainable parameters: 797184
+
+ Starting federated diffusion training for 5 rounds...
+ Using base model: runwayml/stable-diffusion-v1-5
+ Training LoRA parameters only (256 layers)
+INFO :      Starting FedAvg strategy:
+INFO :      ├── Number of rounds: 5
+INFO :      ├── ArrayRecord (3.07 MB)
+INFO :      ├── ConfigRecord (train): (empty!)
+INFO :      ├── ConfigRecord (evaluate): (empty!)
+INFO :      ├──> Sampling:
+INFO :      │       ├──Fraction: train (1.00) | evaluate ( 1.00)
+INFO :      │       ├──Minimum nodes: train (2) | evaluate (2)
+INFO :      │       └──Minimum available nodes: 2
+INFO :      └──> Keys in records:
+INFO :              ├── Weighted by: 'num-examples'
+INFO :              ├── ArrayRecord key: 'arrays'
+INFO :              └── ConfigRecord key: 'config'
+INFO :
+INFO :
+INFO :      [ROUND 1/5]
+INFO :      configure_train: Sampled 2 nodes (out of 2)
+(ClientAppActor pid=73271) /opt/homebrew/Caskroom/miniconda/base/envs/diffusion-env/lib/python3.9/site-packages/flwr_datasets/utils.py:109: UserWarning: The currently tested dataset are ['mnist', 'ylecun/mnist', 'cifar10', 'uoft-cs/cifar10', 'fashion_mnist', 'zalando-datasets/fashion_mnist', 'sasha/dog-food', 'zh-plus/tiny-imagenet', 'scikit-learn/adult-census-income', 'cifar100', 'uoft-cs/cifar100', 'svhn', 'ufldl-stanford/svhn', 'sentiment140', 'stanfordnlp/sentiment140', 'speech_commands', 'LIUM/tedlium', 'flwrlabs/femnist', 'flwrlabs/ucf101', 'flwrlabs/ambient-acoustic-context', 'jlh/uci-mushrooms', 'Mike0307/MNIST-M', 'flwrlabs/usps', 'scikit-learn/iris', 'flwrlabs/pacs', 'flwrlabs/cinic10', 'flwrlabs/caltech101', 'flwrlabs/office-home', 'flwrlabs/fed-isic2019']. Given: nkirschi/oxford-flowers.
+(ClientAppActor pid=73271) Partition 0: 400 training samples, 80 test samples
+Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]
+INFO :      aggregate_train: Aggregated MetricRecord: {'loss': 0.267623713389039}
+INFO :      configure_evaluate: Sampled 2 nodes (out of 2)
+(ClientAppActor pid=73271) Partition 1: 400 training samples, 80 test samples
+Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]
+INFO :      aggregate_evaluate: Aggregated MetricRecord: {'loss': 0.3734127514064312}
+INFO :
+INFO :      [ROUND 2/5]
+.
+.
+INFO :      [ROUND 3/5]
+.
+.
+INFO :      [ROUND 4/5]
+.
+.
+INFO :      [ROUND 5/5]
+.
+.
+INFO :
+INFO :      Strategy execution finished in 35154.13s
+INFO :
+INFO :      Final results:
+INFO :
+INFO :          Global Arrays:
+INFO :              ArrayRecord (3.072 MB)
+INFO :
+INFO :          Aggregated ClientApp-side Train Metrics:
+INFO :          { 1: {'loss': '2.6762e-01'},
+INFO :            2: {'loss': '2.4918e-01'},
+INFO :            3: {'loss': '2.5092e-01'},
+INFO :            4: {'loss': '2.5719e-01'},
+INFO :            5: {'loss': '2.4323e-01'}}
+INFO :
+INFO :          Aggregated ClientApp-side Evaluate Metrics:
+INFO :          { 1: {'loss': '3.7341e-01'},
+INFO :            2: {'loss': '3.2973e-01'},
+INFO :            3: {'loss': '3.5631e-01'},
+INFO :            4: {'loss': '3.6203e-01'},
+INFO :            5: {'loss': '3.3863e-01'}}
+INFO :
+INFO :          ServerApp-side Evaluate Metrics:
+INFO :          {}
+INFO :
+Saving final federated LoRA adapter...
+Saved final LoRA model at: final_lora_model
+```
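+
+### Generate images with the final adapter
+
+At the end of the run, the `ServerApp` saves the aggregated LoRA adapter to `final_lora_model/` and generates a sample image (see `generate_image` in `task.py`). If you want to experiment with the adapter outside of Flower, the sketch below mirrors that function; it assumes the run completed and that `final_lora_model/` sits in the directory you launched `flwr run` from:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+from peft import PeftModel
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load the frozen base model and attach the federated LoRA adapter
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+)
+pipe.unet = PeftModel.from_pretrained(pipe.unet, "final_lora_model")
+pipe = pipe.to(device)
+
+prompt = "A realistic image of a blooming yellow daffodil in natural sunlight"
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("federated_diffusion_sample.png")
+```
+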
diff --git a/examples/quickstart-diffusion/_static/federated_diffusion_sample.png b/examples/quickstart-diffusion/_static/federated_diffusion_sample.png
new file mode 100644
index 000000000000..913f05b66c86
Binary files /dev/null and b/examples/quickstart-diffusion/_static/federated_diffusion_sample.png differ
diff --git a/examples/quickstart-diffusion/diffusion_example/__init__.py b/examples/quickstart-diffusion/diffusion_example/__init__.py
new file mode 100644
index 000000000000..23a09c9602bb
--- /dev/null
+++ b/examples/quickstart-diffusion/diffusion_example/__init__.py
@@ -0,0 +1 @@
+"""diffusion_example: A Flower / Diffusion Model Training app."""
diff --git a/examples/quickstart-diffusion/diffusion_example/client_app.py b/examples/quickstart-diffusion/diffusion_example/client_app.py
new file mode 100644
index 000000000000..275868eda551
--- /dev/null
+++ b/examples/quickstart-diffusion/diffusion_example/client_app.py
@@ -0,0 +1,111 @@
+"""Memory-optimized Flower client for LoRA fine-tuning of Stable Diffusion."""
+
+import warnings
+import torch
+import gc
+from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+from flwr.clientapp import ClientApp
+from peft import get_peft_model_state_dict
+from diffusion_example.task import get_lora_model, train_lora_step, load_data, evaluate_lora_step
+
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+# Flower ClientApp
+app = ClientApp()
+
+
+@app.train()
+def train(msg: Message, context: Context) -> Message:
+    """Perform one local training step with memory optimizations."""
+
+    # Clear memory
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+
+    base_model = context.run_config["base-model"]
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    image_size = int(context.run_config.get("image-size", 64))
+    batch_size = int(context.run_config.get("batch-size", 8))
+
+    train_loader, _ = load_data(partition_id, num_partitions, image_size=image_size, batch_size=batch_size)
+
+    pipe, model_dtype = get_lora_model(base_model, device)
+
+    # Restore only LoRA parameters
+    arrays = msg.content["arrays"]
+    current_state = pipe.unet.state_dict()
+
+    # Update only the LoRA parameters, keep base model frozen
+    for key in arrays.keys():
+        if key in current_state:
+            current_state[key] = torch.tensor(arrays[key].numpy())
+
+    pipe.unet.load_state_dict(current_state, strict=False)
+
+    loss_train = train_lora_step(pipe, train_loader, device, model_dtype)
+
+    # Extract only LoRA parameters for sending back
+    lora_state_dict = get_peft_model_state_dict(pipe.unet)
+
+    del pipe
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
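+
+    # Only the LoRA tensors (a few MB; cf. the ~3 MB ArrayRecord in the logs) are
+    # returned to the server; the frozen base model never leaves the client.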
+
+    metrics = MetricRecord({
+        "loss": float(loss_train),
+        "num-examples": len(train_loader.dataset)
+    })
+
+    content = RecordDict({
+        "arrays": ArrayRecord(lora_state_dict),
+        "metrics": metrics
+    })
+
+    return Message(content=content, reply_to=msg)
+
+
+@app.evaluate()
+def evaluate(msg: Message, context: Context) -> Message:
+    """Evaluate with memory optimizations."""
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+
+    base_model = context.run_config["base-model"]
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+
+    image_size = int(context.run_config.get("image-size", 64))
+    batch_size = int(context.run_config.get("batch-size", 8))
+
+    _, test_loader = load_data(partition_id, num_partitions, image_size=image_size, batch_size=batch_size)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    pipe, model_dtype = get_lora_model(base_model, device)
+
+    # Restore parameters
+    arrays = msg.content["arrays"]
+    current_state = pipe.unet.state_dict()
+    for key in arrays.keys():
+        if key in current_state:
+            current_state[key] = torch.tensor(arrays[key].numpy())
+
+    pipe.unet.load_state_dict(current_state, strict=False)
+
+    loss_val = evaluate_lora_step(pipe, test_loader, device, model_dtype)
+
+    del pipe
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    metrics = MetricRecord({
+        "loss": float(loss_val),
+        "num-examples": len(test_loader.dataset)
+    })
+
+    return Message(content=RecordDict({"metrics": metrics}), reply_to=msg)
diff --git a/examples/quickstart-diffusion/diffusion_example/server_app.py b/examples/quickstart-diffusion/diffusion_example/server_app.py
new file mode 100644
index 000000000000..1ce4601146f4
--- /dev/null
+++ b/examples/quickstart-diffusion/diffusion_example/server_app.py
@@ -0,0 +1,63 @@
+"""Memory-optimized Flower server for federated LoRA fine-tuning."""
+import os
+import torch
+from flwr.app import ArrayRecord, Context
+from flwr.serverapp import Grid, ServerApp
+from flwr.serverapp.strategy import FedAvg
+from peft import get_peft_model_state_dict, LoraConfig
+from diffusion_example.task import get_lora_model, generate_image
+
+app = ServerApp()
+
+
+@app.main()
+def main(grid: Grid, context: Context) -> None:
+    # Load base model and extract initial LoRA parameters only
+    base_model = context.run_config["base-model"]
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Load model just to get LoRA structure
+    pipe, torch_dtype = get_lora_model(base_model, device)
+
+    peft_state_dict = get_peft_model_state_dict(pipe.unet)
+    initial_arrays = ArrayRecord(peft_state_dict)
+
+    del pipe
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Configure training strategy
+    fraction_train = float(context.run_config["fraction-train"])
+    fraction_evaluate = float(context.run_config["fraction-evaluate"])
+    num_rounds = int(context.run_config["num-server-rounds"])
+
+    strategy = FedAvg(
+        fraction_train=fraction_train,
+        fraction_evaluate=fraction_evaluate,
+    )
+
+    # Start the federated training
+    print(f"\n Starting federated diffusion training for {num_rounds} rounds...")
+    print(f" Using base model: {base_model}")
+    print(f" Training LoRA parameters only ({len(peft_state_dict)} layers)")
+
+    result = strategy.start(
+        grid=grid,
+        initial_arrays=initial_arrays,
+        num_rounds=num_rounds,
+        timeout=7200,
+    )
+
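+    # Persist the aggregated adapter: the averaged LoRA weights plus a matching
+    # LoraConfig, so the folder can later be loaded with `PeftModel.from_pretrained`
+    # (as done in `generate_image` in task.py).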
+    final_state = result.arrays.to_torch_state_dict()
+    save_dir = "final_lora_model"
+    os.makedirs(save_dir, exist_ok=True)
+    torch.save(final_state, os.path.join(save_dir, "adapter_model.bin"))
+    config = LoraConfig(
+        r=4,
+        lora_alpha=8,
+        target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # Only attention layers
+        lora_dropout=0.0,
+        bias="none",
+    )
+    config.save_pretrained(save_dir)
+    print(f"Saved final LoRA model at: {save_dir}")
+    generate_image(context, device, torch_dtype)
diff --git a/examples/quickstart-diffusion/diffusion_example/task.py b/examples/quickstart-diffusion/diffusion_example/task.py
new file mode 100644
index 000000000000..3b0a6f3bf220
--- /dev/null
+++ b/examples/quickstart-diffusion/diffusion_example/task.py
@@ -0,0 +1,243 @@
+import os
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+from diffusers import StableDiffusionPipeline
+from flwr_datasets import FederatedDataset
+from peft import LoraConfig, get_peft_model, PeftModel
+from flwr_datasets.partitioner import IidPartitioner
+from torch.utils.data import DataLoader
+from torchvision import transforms
+
+import warnings
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+fds = None
+
+
+def collate_fn(batch):
+    pixel_values = torch.stack([torch.as_tensor(item["pixel_values"]) for item in batch])
+    return {"pixel_values": pixel_values}
+
+
+def load_data(
+    partition_id: int,
+    num_partitions: int,
+    image_size: int = 64,
+    batch_size: int = 8
+) -> tuple[DataLoader, DataLoader]:
+    """Load Oxford Flowers data for diffusion model training."""
+
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="nkirschi/oxford-flowers",
+            partitioners={"train": partitioner}
+        )
+
+    partition = fds.load_partition(partition_id)
+
+    # --- Image preprocessing for RGB flowers ---
+    transform = transforms.Compose([
+        transforms.Resize((image_size, image_size)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+    ])
+
+    def transform_function(examples):
+        key = "image" if "image" in examples else "img"
+        images = []
+        for img in examples[key]:
+            if not isinstance(img, Image.Image):
+                img = Image.fromarray(img)
+            img = img.convert("RGB")  # ensure 3 channels
+            images.append(transform(img))
+        return {"pixel_values": images}
+
+    partition = partition.map(
+        transform_function,
+        batched=True,
+        remove_columns=list(partition.column_names)
+    )
+
+    # Split into train/test (20% test)
+    partition_train_test = partition.train_test_split(test_size=0.20, seed=42)
+
+    # Limit dataset size for quick federated demo
+    def limit_dataset(dataset, n_samples, seed=None):
+        if seed is None:
+            seed = int(time.time() * 1000) % 2**32
+        rng = np.random.default_rng(seed)
+        n = min(n_samples, len(dataset))
+        indices = rng.choice(len(dataset), size=n, replace=False)
+        return dataset.select(indices)
+
+    partition_train_test["train"] = limit_dataset(partition_train_test["train"], 400)
+    partition_train_test["test"] = limit_dataset(partition_train_test["test"], 80)
+
+    trainload = DataLoader(
+        partition_train_test["train"],
+        shuffle=True,
+        batch_size=batch_size,
+        num_workers=0,
+        collate_fn=collate_fn
+    )
+
+    testload = DataLoader(
+        partition_train_test["test"],
+        batch_size=batch_size,
+        num_workers=0,
+        collate_fn=collate_fn
+    )
+    print(f"Partition {partition_id}: {len(partition_train_test['train'])} training samples, "
+          f"{len(partition_train_test['test'])} test samples")
+
+    return trainload, testload
+
+
+def get_lora_model(base_model: str, device: torch.device):
"""Load Stable Diffusion model with memory-optimized LoRA adapters.""" + torch_dtype = torch.float16 if device.type == "cuda" else torch.float32 + + pipe = StableDiffusionPipeline.from_pretrained( + base_model, + torch_dtype=torch_dtype, + safety_checker=None, + requires_safety_checker=False, + use_safetensors=True, + ) + + pipe = enable_memory_efficient_attention(pipe) + pipe.to(device) + + lora_config = LoraConfig( + r=4, + lora_alpha=8, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], # Only attention layers + lora_dropout=0.0, + bias="none", + ) + + pipe.unet = get_peft_model(pipe.unet, lora_config) + print(f"Trainable parameters: {sum(p.numel() for p in pipe.unet.parameters() if p.requires_grad)}") + return pipe, torch_dtype + + +def enable_memory_efficient_attention(pipe): + """Enable memory-efficient attention mechanisms.""" + try: + pipe.unet.set_use_memory_efficient_attention_xformers(True) + except: + try: + pipe.unet.enable_attention_slicing() + except: + pass + + # Enable CPU offloading for VAE and text encoder + try: + pipe.vae.enable_slicing() + pipe.enable_attention_slicing() + except: + pass + + return pipe + +def train_lora_step(pipe, dataloader, device, model_dtype): + """Perform a single memory-efficient LoRA update.""" + pipe.unet.train() + pipe.unet.requires_grad_(True) + + # Only optimize LoRA parameters + lora_params = [p for p in pipe.unet.parameters() if p.requires_grad] + optimizer = torch.optim.Adam(lora_params, lr=1e-5) + + total_loss = 0.0 + num_batches = 0 + for batch in dataloader: + images = batch["pixel_values"] + if isinstance(images, list): + images = torch.stack(images) + + images = images.to(device, dtype=model_dtype) + optimizer.zero_grad() + + # Encode images to latents (diffusion model training) + with torch.no_grad(): + latents = pipe.vae.encode(images).latent_dist.sample() * 0.18215 + + # Sample random timestep and noise + timestep = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (images.size(0),), device=device).long() + noise = torch.randn_like(latents, dtype=model_dtype) + noisy_latents = pipe.scheduler.add_noise(latents, noise, timestep) + + # UNet forward pass (use empty text embeddings for unconditional training) + text_embeddings = torch.zeros(images.size(0), 77, 768, device=device, dtype=model_dtype) + noise_pred = pipe.unet(noisy_latents, timestep, encoder_hidden_states=text_embeddings).sample + + loss = nn.functional.mse_loss(noise_pred, noise) + loss.backward() + + torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) + optimizer.step() + + total_loss += loss.item() + num_batches += 1 + + return total_loss / num_batches if num_batches > 0 else 0.0 + + +def evaluate_lora_step(pipe, dataloader, device, model_dtype): + """Evaluate LoRA adapters with minimal memory usage.""" + pipe.unet.eval() + total_loss = 0.0 + num_batches = 0 + + with torch.no_grad(): + for batch in dataloader: + pixel_values = batch["pixel_values"] + + # Handle both Tensor and list cases + if isinstance(pixel_values, torch.Tensor): + images = pixel_values.to(device, dtype=model_dtype) + elif isinstance(pixel_values, list): + images = torch.stack(pixel_values).to(device, dtype=model_dtype) + else: + raise TypeError(f"Unexpected batch type: {type(pixel_values)}") + + latents = pipe.vae.encode(images).latent_dist.sample() * 0.18215 + timestep = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (images.size(0),), device=device).long() + noise = torch.randn_like(latents, dtype=model_dtype) + noisy_latents = pipe.scheduler.add_noise(latents, 
+            text_emb = torch.zeros(images.size(0), 77, 768, device=device, dtype=model_dtype)
+            noise_pred = pipe.unet(noisy_latents, timestep, encoder_hidden_states=text_emb).sample
+
+            loss = nn.functional.mse_loss(noise_pred, noise)
+            total_loss += loss.item()
+            num_batches += 1
+
+    return total_loss / num_batches if num_batches > 0 else 0.0
+
+
+def generate_image(context, device, torch_dtype):
+    """Generate a sample image with the final LoRA adapter and save it to disk."""
+    base_model = context.run_config["base-model"]
+    pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
+    pipe = pipe.to(device)
+
+    base_path = "final_lora_model"
+    current_dir = os.getcwd()
+    lora_path = os.path.join(current_dir, base_path)
+    pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_path)
+    pipe = pipe.to(device)
+
+    prompt = context.run_config["prompt"]
+    image = pipe(
+        prompt,
+        num_inference_steps=30,
+        guidance_scale=7.5
+    ).images[0]
+
+    image.save("federated_diffusion_sample.png")
+    print("Image generated and saved as 'federated_diffusion_sample.png'")
diff --git a/examples/quickstart-diffusion/pyproject.toml b/examples/quickstart-diffusion/pyproject.toml
new file mode 100644
index 000000000000..d0f33b7b8841
--- /dev/null
+++ b/examples/quickstart-diffusion/pyproject.toml
@@ -0,0 +1,54 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "diffusion_example"
+version = "1.0.0"
+description = "Federated Learning with Diffusion Model and Flower (Quickstart Example)"
+license = "Apache-2.0"
+authors = [
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+    { name = "Aash Mohammad", email = "2810aash.edu@gmail.com" },
+]
+dependencies = [
+    "flwr[simulation]>=1.23.0",
+    "flwr-datasets>=0.5.0",
+    "torch==2.8.0",
+    "torchvision==0.23.0",
+    "accelerate==0.30.1",
+    "diffusers==0.28.2",
+    "transformers>=4.41.2,<5.0",
+    "peft==0.11.0",
+    "evaluate>=0.4.0,<1.0",
+    "huggingface_hub==0.23.0",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "flwrlabs"
+
+[tool.flwr.app.components]
+serverapp = "diffusion_example.server_app:app"
+clientapp = "diffusion_example.client_app:app"
+
+[tool.flwr.app.config]
+num-server-rounds = 3
+base-model = "runwayml/stable-diffusion-v1-5"
+local-epochs = 1
+num-clients = 2
+fraction-train = 1.0
+fraction-evaluate = 1.0
+prompt = "A realistic image of a blooming yellow daffodil in natural sunlight"
+
+[tool.flwr.federations]
+default = "local-simulation"
+
+[tool.flwr.federations.local-simulation]
+options.num-supernodes = 2
+
+[tool.flwr.federations.local-simulation-gpu]
+options.num-supernodes = 2
+options.backend.client-resources.num-cpus = 2
+options.backend.client-resources.num-gpus = 0.5