A comprehensive implementation of federated learning for MNIST digit classification, featuring robust Byzantine attack detection and mitigation.
- Project Overview
- Requirements
- Environment Setup
- Project Structure
- Running the Solutions
- Implementation Details
- Results Summary
This project implements three levels of machine learning systems for MNIST classification:
- Level 1: Centralized Learning
  - Traditional centralized training approach
  - Single model trained on the entire dataset
  - Baseline for comparison
- Level 2: Federated Learning
  - Distributed training across 10 clients
  - Privacy-preserving collaborative learning
  - FedAvg aggregation algorithm
- Level 3: Robust Federated Learning with Attack Detection
  - Byzantine attack simulation (model poisoning)
  - Pairwise cosine similarity detection
  - Malicious client exclusion mechanism
  - Maintains accuracy despite attacks
- Python 3.10 or higher (NumPy 2.2 requires at least Python 3.10)
- 4GB RAM minimum (8GB recommended)
All required packages are listed in requirements.txt:
- PyTorch 2.9.0+
- TorchVision 0.24.0+
- NumPy 2.2.4+
- Matplotlib 3.10.1+ (optional, for visualization)
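Before running the solutions, the installed versions can be compared against the minimums above. The following is a minimal sketch of such a check and is not part of the project: check_environment, meets_minimum and parse_version are hypothetical helpers built only on the standard library's importlib.metadata. The version parsing is deliberately simple (it keeps leading numeric components and ignores pre-release tags); the packaging library's Version class is the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the package list above (matplotlib is optional).
MINIMUMS = {
    "torch": "2.9.0",
    "torchvision": "0.24.0",
    "numpy": "2.2.4",
    "matplotlib": "3.10.1",
}


def parse_version(text):
    """Turn '2.9.0+cu121' into (2, 9, 0); pre-release tags like 'rc1' are ignored."""
    release = text.split("+", 1)[0]  # drop local build tags such as +cu121
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):
            break  # stop at the first non-numeric suffix
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare release numbers, padding the shorter tuple with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment(minimums=MINIMUMS):
    """Return {package: (installed_version_or_None, meets_minimum)}."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_environment().items():
        status = "OK" if ok else "MISSING/TOO OLD"
        print(f"{name:12} {installed or '-':12} {status}")
```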
Option 1: Conda
1. Create a new conda environment:
   conda create -n datasec_env python=3.13
2. Activate the environment:
   conda activate datasec_env
3. Install the required packages:
   pip install -r requirements.txt

Option 2: Python venv
1. Create a virtual environment:
   python -m venv datasec_env
2. Activate the environment:
   - Windows: datasec_env\Scripts\activate
   - Linux/Mac: source datasec_env/bin/activate
3. Install the required packages:
   pip install -r requirements.txt
If you encounter issues with the requirements file, install the packages individually:
pip install torch torchvision numpy matplotlib

Run the test script to verify your environment:
python test.py

Expected output: training for 3 epochs, final accuracy ~98%.
CSC8370_Federated_learning_Project/
├── README.md # This file
├── requirements.txt # Essential packages only
│
├── test.py # Environment verification script
├── best_model.pth # Saved model from test.py
│
├── Templates/ # Solution implementations
│ ├── Level_1_Solution.py # Centralized Learning
│ ├── Level_2_solution_1.py # Federated Learning
│ ├── Level_3_Solution.py # Robust FL with Attack Detection
│ ├── dataloader4level1.py # (Reference)
│ └── a federated learning framework.py # (Documentation)
│
└── data/ # Auto-downloaded MNIST dataset
└── mnist/
Ensure your environment is activated:
conda activate datasec_env # or your environment name

Level 1: Centralized Learning
Objective: Train a CNN on the full MNIST dataset in a centralized manner.
Command:
python Templates/Level_1_Solution.py
What it does:
- Loads the entire MNIST training set (60,000 images)
- Trains the CNN for 3 epochs
- Evaluates on the 10,000 test images
- Saves the best model as best_model.pth
Expected Output:
The number of training data: 60000
The number of testing data: 10000
training has:1200 batch of data!
testing has:1 batch of data!
epoch:1,index of train:100,loss: 1.539659,acc:55.30%
...
epoch:3,index of train:1200,loss: 0.047619,acc:98.48%
Best model saved with accuracy: 0.9849
Accuracy: 0.9849
Runtime: ~2-5 minutes (CPU), ~1-2 minutes (GPU)
Expected Accuracy: 97-99%
Level 2: Federated Learning
Objective: Implement federated learning with 10 clients.
Command:
python Templates/Level_2_solution_1.py
What it does:
- Splits the MNIST training set among 10 clients (6,000 images each)
- Each client trains locally for 2 epochs
- The server aggregates the client models using FedAvg
- Repeats for 10 global rounds
- Saves the final model as federated_model.pth
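The README does not say how the 60,000 training images are divided among the clients. The sketch below assumes a uniform random (IID) split into equal, disjoint shards; split_iid is a hypothetical helper, and in PyTorch each returned index list could be wrapped with torch.utils.data.Subset to build one client's dataset.

```python
import random


def split_iid(n_samples, n_clients, seed=0):
    """Shuffle sample indices and deal them into equal, disjoint shards.

    Assumes n_samples is divisible by n_clients (60,000 / 10 = 6,000).
    """
    if n_samples % n_clients:
        raise ValueError("n_samples must be divisible by n_clients")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    shard = n_samples // n_clients
    return [indices[i * shard:(i + 1) * shard] for i in range(n_clients)]


shards = split_iid(60_000, 10)
print(len(shards), len(shards[0]))  # 10 6000
```

A random split keeps each client's label distribution close to the global one; a non-IID split (for example, sorting by label first) would make the federated setting harder.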
Expected Output:
Global Epoch 1/10
Global Model Test Accuracy after round 1: 0.9039
Global Epoch 2/10
Global Model Test Accuracy after round 2: 0.9470
...
Global Epoch 10/10
Global Model Test Accuracy after round 10: 0.9851
Federated learning process completed.
Runtime: ~5-10 minutes (CPU), ~2-4 minutes (GPU)
Expected Accuracy: 97-99%
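The FedAvg aggregation step mentioned above can be sketched in plain Python as follows. This assumes each client's model is represented as a dict mapping parameter names to flat lists of floats, a stand-in for a PyTorch state_dict; fed_avg is a hypothetical name. Weights follow the FedAvg rule n_k / n, which reduces to a plain mean when every client holds 6,000 images.

```python
def fed_avg(client_states, client_sizes=None):
    """Average client parameters key by key.

    client_states: list of dicts mapping parameter name -> list of floats
    client_sizes:  optional per-client sample counts; weights are n_k / sum(n).
                   With equal shards this is the same as a plain mean.
    """
    if client_sizes is None:
        client_sizes = [1] * len(client_states)
    total = sum(client_sizes)
    weights = [n / total for n in client_sizes]
    averaged = {}
    for name in client_states[0]:
        length = len(client_states[0][name])
        averaged[name] = [
            sum(w * state[name][i] for w, state in zip(weights, client_states))
            for i in range(length)
        ]
    return averaged


clients = [{"fc.weight": [1.0, 2.0]}, {"fc.weight": [3.0, 4.0]}]
print(fed_avg(clients))                        # {'fc.weight': [2.0, 3.0]}
print(fed_avg(clients, client_sizes=[1, 3]))   # {'fc.weight': [2.5, 3.5]}
```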
Key Parameters (configurable at line 109):
federated_learning(
n_clients=10, # Number of clients
global_epochs=10, # Global rounds
local_epochs=2 # Local training epochs per round
)

Level 3: Robust Federated Learning with Attack Detection
Objective: Detect and mitigate Byzantine attacks in federated learning.
Command:
python Templates/Level_3_Solution.py
What it does:
- Simulates federated learning with 10 clients
- Client 5 becomes malicious from Round 5 onward
- The malicious client injects random noise into its model (model poisoning)
- The detection mechanism uses pairwise cosine similarity
- Detected malicious clients are excluded from aggregation
- Model accuracy is maintained despite the ongoing attack
- Saves the final model as robust_federated_model.pth
Expected Output:
======================================================================
ROBUST FEDERATED LEARNING WITH ATTACK DETECTION
======================================================================
Configuration:
- Clients: 10
- Global Epochs: 10
- Local Epochs: 2
- Malicious Client: 5
- Attack Starts: Round 5
======================================================================
Global Epoch 1/10
Running malicious client detection...
Aggregating 10 benign clients (excluded 0 malicious)
Global Model Test Accuracy: 0.8988
======================================================================
...
======================================================================
Global Epoch 5/10
Client 5: MALICIOUS (injecting false updates)
Running malicious client detection...
Client 5 flagged as malicious (avg similarity: 0.0095)
Aggregating 9 benign clients (excluded 1 malicious)
Global Model Test Accuracy: 0.9753
Malicious client 5 successfully detected!
======================================================================
...
ROBUST FEDERATED LEARNING COMPLETED
Final Accuracy: 0.9851
======================================================================
Runtime: ~5-10 minutes (CPU), ~2-4 minutes (GPU)
Expected Accuracy: 97-99% (maintained despite attack!)
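A minimal sketch of the pairwise cosine similarity check is shown below, assuming each client is represented by its flattened parameter vector (the README does not say whether full parameters or per-round updates are compared). cosine and detect_malicious are hypothetical helpers; the 0.8 threshold matches the default in this project, and the synthetic data only imitates the setup (client 5 perturbed with Gaussian noise of standard deviation 10).

```python
import math
import random


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)


def detect_malicious(client_vectors, threshold=0.8):
    """Flag clients whose average similarity to all other clients is below threshold.

    Returns (flagged_client_ids, average_similarity_per_client).
    """
    n = len(client_vectors)
    avg = []
    for i in range(n):
        sims = [cosine(client_vectors[i], client_vectors[j]) for j in range(n) if j != i]
        avg.append(sum(sims) / len(sims))
    flagged = [i for i, s in enumerate(avg) if s < threshold]
    return flagged, avg


# Synthetic demo: 10 clients share a base model; client 5 is poisoned.
rng = random.Random(0)
base = [rng.uniform(-1.0, 1.0) for _ in range(500)]
clients = [[w + rng.gauss(0.0, 0.05) for w in base] for _ in range(10)]
clients[5] = [w + rng.gauss(0.0, 10.0) for w in base]
flagged, avg = detect_malicious(clients)
print("flagged:", flagged)  # flagged: [5]
```

In this toy example each benign client's average also includes its near-zero similarity to the attacker, which pulls it noticeably below the benign-to-benign similarity; a threshold of 0.8 still leaves a wide margin on both sides.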
Key Parameters (configurable at lines 244-248):
robust_federated_learning(
n_clients=10, # Number of clients
global_epochs=10, # Global rounds (required: 10)
local_epochs=2, # Local training epochs
malicious_client_id=5, # Which client is malicious
attack_start_round=5 # When attack begins
)

Model Architecture (CNN):
Input: 28×28 grayscale image
Conv1: 1→6 channels, 5×5 kernel → ReLU → MaxPool(2×2)
Conv2: 6→16 channels, 5×5 kernel → ReLU → MaxPool(2×2)
Flatten: 16×4×4 = 256
FC1: 256→120 → ReLU
FC2: 120→84 → ReLU
FC3: 84→10 (output classes)

FedAvg Aggregation:
For each global round:
1. Distribute global model to all clients
2. Each client trains locally on their data
3. Clients send updated models to server
4. Server averages all client models (every client holds 6,000 samples, so FedAvg's sample-size weighting reduces to a plain mean):
global_params = mean(client_params[0], ..., client_params[9])
5. Repeat

Pairwise Cosine Similarity Detection:
1. For each client, compute its cosine similarity with every other client
2. Average these similarities for each client
3. Benign clients typically score avg_similarity > 0.95 (very similar)
4. The malicious client scores avg_similarity < 0.1 (very different)
5. Flag any client whose average falls below the threshold of 0.8
6. Exclude flagged clients from aggregation

Attack Simulation:
- Category: Model Poisoning / Byzantine Attack
- Method: Random Gaussian noise injection (magnitude: 10.0)
- Target: Model parameters directly
- Impact: Without detection, accuracy drops to ~10-20%
- With detection: Accuracy maintained at ~98%
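The attack described above can be sketched as follows, assuming that "magnitude: 10.0" means zero-mean Gaussian noise with standard deviation 10.0 added to every parameter (in PyTorch, roughly param + torch.randn_like(param) * 10.0). poison is a hypothetical helper, not the project's code; the example shows why the poisoned model stands out, since its cosine similarity to the clean parameters ends up close to zero.

```python
import math
import random


def poison(params, magnitude=10.0, rng=None):
    """Return a copy of params with N(0, magnitude^2) noise added to every entry."""
    rng = rng if rng is not None else random.Random()
    return [p + rng.gauss(0.0, magnitude) for p in params]


def cosine(u, v):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


rng = random.Random(42)
clean = [rng.uniform(-0.1, 0.1) for _ in range(1000)]  # small, trained-looking weights
poisoned = poison(clean, rng=rng)
print(round(cosine(clean, poisoned), 3))  # near 0: the noise swamps the real weights
```

Because the noise is orders of magnitude larger than typical weights, averaging even one such model into FedAvg would dominate the result, which is consistent with the accuracy collapse reported without detection.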
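The Flatten size of 256 in the CNN architecture above follows from the usual output-size formula floor((size + 2·padding - kernel) / stride) + 1, which is also the shape formula in PyTorch's Conv2d and MaxPool2d documentation (with dilation 1). This sketch assumes stride 1 and no padding for both convolutions, which is consistent with the stated 16×4×4; conv_out and pool_out are hypothetical helpers.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Spatial size after max pooling without padding."""
    return (size - kernel) // stride + 1


size = 28                                  # input: 28x28 grayscale
size = pool_out(conv_out(size, kernel=5))  # conv1 5x5 -> 24, pool 2x2 -> 12
size = pool_out(conv_out(size, kernel=5))  # conv2 5x5 -> 8, pool 2x2 -> 4
flat = 16 * size * size                    # 16 channels after conv2
print(size, flat)  # 4 256
```

With padding=2 on the first convolution (as in the classic LeNet-5 setup), the same arithmetic would give 16×5×5 = 400 instead, so FC1's input size must change together with the padding.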
| Metric | Level 1 | Level 2 | Level 3 |
|---|---|---|---|
| Approach | Centralized | Federated (10 clients) | Robust FL |
| Training Mode | Single server | Distributed | Distributed + Defense |
| Epochs | 3 | 10 global × 2 local | 10 global × 2 local |
| Attack Present | No | No | Yes (Round 5+) |
| Detection Active | N/A | N/A | Yes |
| Final Accuracy | 98.49% | 98.51% | 98.51% |
| Detection Rate | N/A | N/A | 100% (6/6 rounds) |
| Runtime (CPU) | ~3 min | ~7 min | ~8 min |
- Federated learning achieves accuracy comparable to centralized learning (98.51% vs. 98.49%)
- The Byzantine attack was detected in all 6 attack rounds
- Model accuracy was maintained at 98.51% despite the ongoing attack
- Zero false positives: only the malicious client was flagged
- System remains functional with 9/10 clients after exclusion
Problem: PyTorch is not installed (import error for torch)
Solution: Install PyTorch:
pip install torch torchvision

Problem: No CUDA-capable GPU is available
Solution: The code automatically falls back to CPU. To force CPU execution, set:
device = torch.device('cpu') # Line 58/69/167 in solution files

Problem: Out-of-memory errors
Solution: Reduce the batch size or the number of data-loader workers:
batch_size = 25 # Reduce from 50

Problem: The MNIST download fails
Solution: Download MNIST manually from http://yann.lecun.com/exdb/mnist/ and place it in ./data/mnist/

Problem: Training is slow
Solution: This is normal behavior. CPU training takes 2-3× longer than GPU training.
| Parameter | Location | Default | Description |
|---|---|---|---|
| learning_rate | Level 1: line 62 | 0.0005 | Optimizer learning rate |
| batch_size | Level 1: line 14 | 50 | Training batch size |
| epoches | Level 1: line 63 | 3 | Training epochs |
| n_clients | Level 2: line 109 | 10 | Number of federated clients |
| global_epochs | Level 2: line 109 | 10 | Federated rounds |
| local_epochs | Level 2: line 109 | 2 | Client training epochs |
| threshold | Level 3: line 220 | 0.8 | Detection threshold |
| malicious_client_id | Level 3: line 247 | 5 | Which client attacks |
| attack_start_round | Level 3: line 248 | 5 | When attack begins |
- McMahan et al., "Communication-Efficient Learning of Deep Networks from Decentralized Data" (original FedAvg paper), 2017. https://arxiv.org/abs/1602.05629
- Lyu et al., federated learning survey, 2020. https://arxiv.org/abs/2007.10747
- Blanchard et al., "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent", 2017. https://arxiv.org/abs/1703.02757
Course: CSC8370 - Data Security
Institution: Georgia State University
Semester: 2025
Project Template: Provided by TA Dong Yang
Implementation: Rohit Arodi Ramachandra (002830329) & Ashish Reddy Mandadi (002850578)
Email: rarodiramachandra1@student.gsu.edu
Last Updated: November 2025
Version: 1.0