FedPM: Federated Learning Using Second-order Optimization with Preconditioned Mixing of Local Parameters

This repository contains the official code implementation for the paper "FedPM: Federated Learning Using Second-order Optimization with Preconditioned Mixing of Local Parameters" (AAAI 2026).

Overview

FedPM is a novel federated learning method that leverages second-order optimization with preconditioned mixing of local parameters on the server. This implementation includes experiments for both strongly convex models (Test 1) and non-convex deep learning models (Test 2).
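
To give intuition for the server-side step, here is a minimal sketch of one way preconditioned mixing can be read: each client contributes its local parameters together with a local curvature (preconditioner) matrix, and the server forms a curvature-weighted average instead of a plain average. This is an illustrative assumption about the update, not the repository's code; the function name mix_parameters and the exact weighting are placeholders.

import numpy as np

def mix_parameters(local_params, local_preconds):
    """Illustrative sketch: combine client parameters theta_i using local
    preconditioners H_i, i.e. theta = (sum_i H_i)^{-1} (sum_i H_i theta_i).
    With H_i = I this reduces to plain parameter averaging (FedAvg)."""
    H_sum = sum(local_preconds)
    weighted_sum = sum(H @ p for H, p in zip(local_preconds, local_params))
    return np.linalg.solve(H_sum, weighted_sum)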

Requirements

  • Python 3.8+
  • PyTorch 2.1.0
  • CUDA 11.8 (for GPU support)
  • See requirements.txt for complete dependencies

Installation

  1. Install dependencies:
pip install -r requirements.txt
  2. Set up the data directory (optional):
export DATA_DIR=./data  # Default data directory

Basic Usage

The main entry point is main.py. All commands below use the --sequential flag, which runs clients sequentially so experiments can be tested without requiring multiple GPUs.

To disable Weights & Biases logging, add the --disable_wandb flag to any command.

Test 1: Strongly Convex Models (Logistic Regression)

These experiments use logistic regression with L2 regularization on the w8a and a9a datasets from LibSVM.
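
For reference, the per-client objective in these experiments is the standard L2-regularized logistic loss. The NumPy sketch below is illustrative (the function name and the exact scaling of the regularizer are assumptions, with lam corresponding to --logreg_lambda):

import numpy as np

def logreg_loss(w, X, y, lam):
    """L2-regularized logistic loss with labels y in {-1, +1}:
    mean_i log(1 + exp(-y_i * x_i @ w)) + (lam / 2) * ||w||^2."""
    margins = y * (X @ w)
    return np.mean(np.log1p(np.exp(-margins))) + 0.5 * lam * np.dot(w, w)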

FedAvg/PSGD

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedavg --logreg_lambda {lambda} --logreg_init_sigma {sigma} --batch_size {batch_size} --lr {learning_rate} --sequential

FedAvgM

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedavgm --logreg_lambda {lambda} --logreg_init_sigma {sigma} --fedavgm_server_momentum {momentum} --batch_size {batch_size} --lr {learning_rate} --sequential

FedNS

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedns --logreg_lambda {lambda} --logreg_init_sigma {sigma} --fedns_sketch_size {sketch_size} --batch_size {batch_size} --lr {learning_rate} --sequential

SCAFFOLD

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm scaffold --logreg_lambda {lambda} --logreg_init_sigma {sigma} --batch_size {batch_size} --lr {learning_rate} --sequential

FedAdam

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fedadam --logreg_lambda {lambda} --logreg_init_sigma {sigma} --fedadam_server_lr {server_lr} --batch_size {batch_size} --lr {learning_rate} --sequential

LocalNewton

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm newton --logreg_lambda {lambda} --logreg_init_sigma {sigma} --mixing_override fedavg --batch_size {batch_size} --lr {learning_rate} --sequential

FedPM (Proposed Method)

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm newton --logreg_lambda {lambda} --logreg_init_sigma {sigma} --batch_size {batch_size} --lr {learning_rate} --sequential

FedNL

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_loop {steps} --model logreg_fednl --dataset {a9a|w8a} --algorithm fednl --logreg_lambda {lambda} --logreg_init_sigma {sigma} --mixing_override fedavg --batch_size {batch_size} --lr {learning_rate} --sequential

Test 2: Non-convex Deep Learning Models

These experiments use CNN and ResNet-18 models on the CIFAR10 and CIFAR100 datasets with heterogeneous (non-IID) data distributions across clients.
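
Data heterogeneity is controlled by --dirichlet_alpha (see Parameter Guidelines below). As a rough, illustrative sketch of how Dirichlet-based label partitioning is typically done (the repository's own partitioning may differ in details such as balancing):

import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Illustrative sketch: for each class, split its sample indices across
    clients according to proportions drawn from Dirichlet(alpha).
    Smaller alpha -> more skewed (heterogeneous) client label distributions."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cut_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, cut_points)):
            client_indices[client_id].extend(part.tolist())
    return client_indices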

FedAvg

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedavg --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential

FedAvgM

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedavgm --fedavgm_server_momentum {momentum} --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential

FedProx

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedprox --fedprox_mu {mu} --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --sequential

SCAFFOLD

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm scaffold --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential

FedAdam

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm fedadam --fedadam_server_lr {server_lr} --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential

LocalNewton

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm regmean --regmean_precondition_damping {damping} --regmean_reset_covs --regmean_compute_on_whole_data --regmean_precondition_local_grad --mixing_override fedavg --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential

FedPM (Proposed Method)

python main.py --num_clients {num_clients} --num_participants {num_participants} --communication_rounds {rounds} --inner_epoch {epochs} --model {cnn_feddisco|resnet18_feddisco} --dataset {cifar10|cifar100} --data_dist_strategy random --dirichlet_alpha {alpha} --algorithm regmean --regmean_mixing_damping {mixing_damping} --regmean_precondition_damping {precond_damping} --regmean_reset_covs --regmean_compute_on_whole_data --regmean_cov_batch_size {cov_batch_size} --regmean_precondition_local_grad --batch_size {batch_size} --lr {learning_rate} --max_grad_norm {grad_norm} --weight_decay {weight_decay} --sequential

Parameter Guidelines

Common Parameters

  • num_clients: Total number of clients in the federation
  • num_participants: Number of clients sampled to participate in each round (see the round-structure sketch after this list)
  • communication_rounds: Number of communication rounds
  • batch_size: Local batch size for training
  • lr: Learning rate
  • sequential: Enable sequential execution (no parallelization)
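
To make the interaction between these parameters concrete, the overall training loop can be pictured roughly as below. This is a schematic sketch only: run_federation, local_update, and aggregate are placeholders, not the repository's API.

import random

def run_federation(num_clients, num_participants, communication_rounds,
                   global_model, local_update, aggregate, seed=0):
    """Schematic: each round, num_participants of the num_clients clients are
    sampled, train locally, and the server aggregates their results."""
    rng = random.Random(seed)
    for _ in range(communication_rounds):
        participants = rng.sample(range(num_clients), num_participants)
        local_results = [local_update(global_model, cid) for cid in participants]
        global_model = aggregate(global_model, local_results)
    return global_model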

Test 1 Specific

  • inner_loop: Number of gradient steps between communications
  • logreg_lambda: L2 regularization coefficient
  • logreg_init_sigma: Standard deviation of the initial parameters' deviation from the optimum (see the sketch after this list)
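
The initialization that --logreg_init_sigma refers to can be pictured as the following illustrative snippet (variable and function names are placeholders):

import numpy as np

def init_params(w_star, sigma, seed=0):
    """Illustrative: start from the optimum w_star perturbed by Gaussian noise
    with standard deviation sigma (--logreg_init_sigma)."""
    rng = np.random.default_rng(seed)
    return w_star + sigma * rng.standard_normal(w_star.shape)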

Test 2 Specific

  • inner_epoch: Number of local epochs between communications
  • dirichlet_alpha: Controls data heterogeneity (lower values = more heterogeneous)
  • max_grad_norm: Maximum gradient norm for clipping (-1.0 disables clipping; see the clipping sketch after this list)
  • weight_decay: Weight decay coefficient
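
For context, clipping of this kind is normally applied between the backward pass and the optimizer step. The PyTorch sketch below illustrates the general pattern (it is not lifted from the repository):

import torch

def local_step(model, loss, optimizer, max_grad_norm):
    """Illustrative local step with optional gradient-norm clipping,
    mirroring --max_grad_norm (a negative value disables clipping)."""
    optimizer.zero_grad()
    loss.backward()
    if max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()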

Example Commands

Test 1 Example (w8a dataset)

# FedPM on w8a with 142 clients
python main.py --num_clients 142 --num_participants 142 --communication_rounds 50 --inner_loop 1 --model logreg_fednl --dataset w8a --algorithm newton --logreg_lambda 0.01 --logreg_init_sigma 0.1 --batch_size 350 --lr 0.1 --sequential --disable_wandb

Test 2 Example (CIFAR10)

# FedPM on CIFAR10 with heterogeneous data
python main.py --num_clients 10 --num_participants 10 --communication_rounds 100 --inner_epoch 5 --model cnn_feddisco --dataset cifar10 --data_dist_strategy random --dirichlet_alpha 0.1 --algorithm regmean --regmean_mixing_damping 1.0 --regmean_precondition_damping 1.0 --regmean_reset_covs --regmean_compute_on_whole_data --regmean_cov_batch_size -1 --regmean_precondition_local_grad --batch_size 64 --lr 0.5 --max_grad_norm 1.0 --weight_decay 0.0001 --sequential --disable_wandb

Notes

  • The code automatically downloads CIFAR10 and CIFAR100 on first run
  • Results are logged to console and optionally to Weights & Biases
  • GPU is used automatically if available
  • For reproducibility, use the --seed parameter (see the seeding sketch below)
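
A typical seeding pattern for experiments like these is sketched below (illustrative; the repository's --seed handling may cover additional sources of randomness):

import random
import numpy as np
import torch

def set_seed(seed):
    """Illustrative: seed the common sources of randomness for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)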
