# FedPM: Federated Learning Using Second-order Optimization with Preconditioned Mixing of Local Parameters
This directory contains the supplementary code implementation for the paper "FedPM: Federated Learning Using Second-order Optimization with Preconditioned Mixing of Local Parameters".
## Overview

FedPM is a novel federated learning method that leverages second-order optimization with preconditioned mixing of local parameters on the server. This implementation includes experiments for both strongly convex models (Test 1) and non-convex deep learning models (Test 2).
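For intuition only, the sketch below shows what curvature-weighted ("preconditioned") mixing of local parameters can look like on the server: each client's parameters are weighted by a local curvature matrix (e.g., a Hessian) before being combined. This is our own illustration with hypothetical names, not the exact update rule from the paper:

```python
import numpy as np

def preconditioned_mixing(local_params, local_hessians, damping=1e-3):
    """Illustrative only: combine client parameters weighted by local curvature.

    local_params:   list of K vectors of shape (d,), one per client
    local_hessians: list of K matrices of shape (d, d), local preconditioners
    """
    d = local_params[0].shape[0]
    # Aggregate preconditioner, damped for numerical stability
    h_sum = sum(local_hessians) + damping * np.eye(d)
    # Curvature-weighted sum of the local parameter vectors
    weighted = sum(h @ w for h, w in zip(local_hessians, local_params))
    # Solve (sum_k H_k) w = sum_k H_k w_k instead of plain averaging
    return np.linalg.solve(h_sum, weighted)
```

With identity Hessians this reduces (up to the damping term) to plain parameter averaging, i.e., FedAvg-style mixing.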
## Requirements

- Python 3.8+
- PyTorch 2.1.0
- CUDA 11.8 (for GPU support)
- See `requirements.txt` for complete dependencies
## Installation

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Set up the data directory (optional):

  ```bash
  export DATA_DIR=./data  # Default data directory
  ```

## Usage

The main entry point is `main.py`. The code is designed to run sequentially for easy testing without requiring multiple GPUs.
To disable Weights & Biases logging, add the `--disable_wandb` flag to any command.
### Test 1: Strongly Convex Models

These experiments use logistic regression with L2 regularization on the w8a and a9a datasets from LibSVM.
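For reference, the per-client objective is the L2-regularized logistic loss; a minimal NumPy version (our sketch, with `logreg_lambda` playing the role of the regularization coefficient) is:

```python
import numpy as np

def logreg_loss(w, X, y, logreg_lambda):
    """L2-regularized logistic loss with labels y in {-1, +1}.

    mean_i log(1 + exp(-y_i * <x_i, w>)) + (logreg_lambda / 2) * ||w||^2
    """
    margins = y * (X @ w)
    # logaddexp(0, -m) == log(1 + exp(-m)), computed stably
    return np.mean(np.logaddexp(0.0, -margins)) + 0.5 * logreg_lambda * np.dot(w, w)
```

The `logreg_init_sigma` parameter (listed under the Test 1 parameters below) controls the standard deviation of the initial parameters' deviation from the optimum.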
**FedAvg:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedavg --logreg_lambda {lambda} --logreg_init_sigma {sigma} --batch_size {batch_size} --lr {learning_rate} --sequential
```

**FedAvgM:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedavgm --logreg_lambda {lambda} --logreg_init_sigma {sigma} --fedavgm_server_momentum {momentum} --batch_size {batch_size} --lr {learning_rate} --sequential
```

**FedNS:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedns --logreg_lambda {lambda} --logreg_init_sigma {sigma} --fedns_sketch_size {sketch_size} --batch_size {batch_size} --lr {learning_rate} --sequential
```

**SCAFFOLD:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm scaffold --logreg_lambda {lambda} --logreg_init_sigma {sigma} --batch_size {batch_size} --lr {learning_rate} --sequential
```

**FedAdam:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedadam --logreg_lambda {lambda} --logreg_init_sigma {sigma} --fedadam_server_lr {server_lr} --batch_size {batch_size} --lr {learning_rate} --sequential
```

**Newton with FedAvg mixing (`--mixing_override fedavg`):**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm newton --logreg_lambda {lambda} --logreg_init_sigma {sigma} --mixing_override fedavg --batch_size {batch_size} --lr {learning_rate} --sequential
```

**FedPM (Newton with preconditioned mixing):**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm newton --logreg_lambda {lambda} --logreg_init_sigma {sigma} --batch_size {batch_size} --lr {learning_rate} --sequential
```

**FedNL with FedAvg mixing:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fednl --logreg_lambda {lambda} --logreg_init_sigma {sigma} --mixing_override fedavg --batch_size {batch_size} --lr {learning_rate} --sequential
```

### Test 2: Non-Convex Deep Learning Models

These experiments use CNN models on CIFAR10 and CIFAR100 datasets with heterogeneous data distribution.
**FedAvg:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedavg --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential
```

**FedAvgM:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedavgm --fedavgm_server_momentum {momentum} --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential
```

**FedProx:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedprox --fedprox_mu {mu} --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --sequential
```

**SCAFFOLD:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm scaffold --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential
```

**FedAdam:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedadam --fedadam_server_lr {server_lr} --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential
```

**FedPM with FedAvg mixing (`--mixing_override fedavg`):**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm regmean --regmean_precondition_damping {damping} --regmean_reset_covs --regmean_compute_on_whole_data --regmean_precondition_local_grad --mixing_override fedavg --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential
```

**FedPM:**

```bash
python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm regmean --regmean_mixing_damping {mixing_damping} --regmean_precondition_damping {precond_damping} --regmean_reset_covs --regmean_compute_on_whole_data --regmean_cov_batch_size {cov_batch_size} --regmean_precondition_local_grad --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential
```

## Key Parameters

### Common Parameters

- `num_clients`: Total number of clients in the federation
- `num_participants`: Number of clients participating in each round
- `communication_rounds`: Number of communication rounds
- `batch_size`: Local batch size for training
- `lr`: Learning rate
- `sequential`: Enable sequential execution (no parallelization)
### Test 1 Parameters

- `inner_loop`: Number of gradient steps between communications
- `logreg_lambda`: L2 regularization coefficient
- `logreg_init_sigma`: Standard deviation of the initial parameters' deviation from the optimum
### Test 2 Parameters

- `inner_epoch`: Number of local epochs between communications
- `dirichlet_alpha`: Controls data heterogeneity (lower values = more heterogeneous); see the sketch after this list
- `max_grad_norm`: Maximum gradient norm for clipping (-1.0 disables clipping)
- `weight_decay`: Weight decay coefficient
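To make `dirichlet_alpha` concrete, the following is a common way to draw a heterogeneous label split from a Dirichlet distribution. This is a generic sketch, not necessarily the exact partitioning code behind `--data_dist_strategy random`:

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with class proportions ~ Dirichlet(alpha).

    Smaller alpha -> more skewed (heterogeneous) per-client label distributions.
    Generic sketch; the repository's own partitioning may differ in details.
    """
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        rng.shuffle(idx)
        # Proportion of class c assigned to each client
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices
```

Roughly, with `alpha = 0.1` most clients end up dominated by a few classes, while large values (e.g., `alpha = 100`) give each client a nearly uniform label mix.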
## Example Commands

```bash
# FedPM on w8a with 142 clients
python main.py --num_clients 142 --num_participants 142 --communication_rounds 50 --inner_loop 1 --model logreg_fednl --dataset w8a --algorithm newton --logreg_lambda 0.01 --logreg_init_sigma 0.1 --batch_size 350 --lr 0.1 --sequential --disable_wandb

# FedPM on CIFAR10 with heterogeneous data
python main.py --num_clients 10 --num_participants 10 --communication_rounds 100 --inner_epoch 5 --model cnn_feddisco --dataset cifar10 --data_dist_strategy random --dirichlet_alpha 0.1 --algorithm regmean --regmean_mixing_damping 1.0 --regmean_precondition_damping 1.0 --regmean_reset_covs --regmean_compute_on_whole_data --regmean_cov_batch_size -1 --regmean_precondition_local_grad --batch_size 64 --lr 0.5 --max_grad_norm 1.0 --weight_decay 0.0001 --sequential --disable_wandb
```

## Notes

- The code automatically downloads CIFAR10 and CIFAR100 on first run
- Results are logged to the console and optionally to Weights & Biases
- GPU is used automatically if available
- For reproducibility, use the `--seed` parameter