
Conversation

@aash-mohammad commented Nov 1, 2025

Federated Diffusion Model Training Example

This example demonstrates how to train a diffusion model in a Federated Learning (FL) environment using Flower and Hugging Face Diffusers.

The goal of this example is to show how privacy-preserving generative models can be collaboratively trained across distributed clients without sharing raw data. Each client trains locally on a partition of the MNIST dataset, and the global model is updated via federated averaging (FedAvg).
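For illustration, here is a minimal sketch of how each client could load its own MNIST partition and how the server could aggregate with FedAvg. The use of `flwr-datasets`, the IID split, and the client count are assumptions made for the sketch, not details taken from this PR:

```python
# Sketch only: assumes flwr-datasets with an IID split across 10 clients.
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from flwr.server.strategy import FedAvg

NUM_CLIENTS = 10  # illustrative value

# Partition MNIST once; each client then loads only its own shard,
# so raw images never leave the client.
fds = FederatedDataset(
    dataset="mnist",
    partitioners={"train": IidPartitioner(num_partitions=NUM_CLIENTS)},
)
partition = fds.load_partition(0, "train")  # e.g. client 0's local data

# On the server, FedAvg averages the client updates into the global model.
strategy = FedAvg()
```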

To optimize training efficiency, the implementation leverages LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning, drastically reducing memory and computation requirements. This setup enables diffusion model training even on systems with limited GPU resources.
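And a sketch of how LoRA adapters could be attached to a diffusers UNet. The use of the `peft` library, the target module names `to_q`/`to_k`/`to_v`/`to_out.0`, and the rank values are assumptions, chosen because standard diffusers attention blocks use those projection names:

```python
# Sketch of LoRA-based parameter-efficient fine-tuning (illustrative values).
from diffusers import UNet2DModel
from peft import LoraConfig, get_peft_model

# A small, MNIST-sized UNet; the real example's architecture may differ.
unet = UNet2DModel(sample_size=32, in_channels=1, out_channels=1)

lora_config = LoraConfig(
    r=8,            # low-rank dimension (illustrative)
    lora_alpha=16,  # scaling factor (illustrative)
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()  # only the small LoRA adapters are trainable
```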

#6045

@aash-mohammad (Author)

I have tested this code on my local system, and it runs successfully on the CPU using a small subset of the MNIST dataset. However, I was unable to test it on a GPU due to hardware limitations on my Mac system.

The github-actions bot added the Contributor label (used to determine what PRs (mainly) come from external contributors) on Nov 1, 2025.
@jafermarq (Member)

Hey @aash-mohammad, thanks for opening a PR for a new example 🚀! We'll take a look in the coming days. What would be a good dataset to use other than MNIST? How about https://huggingface.co/datasets/nkirschi/oxford-flowers 😉?

@aash-mohammad (Author)

Hi @jafermarq,
Thank you for the response!
Yes, the Oxford Flowers dataset would be a great choice: it is well suited to diffusion-based image generation and can better demonstrate the model's generative capabilities than MNIST.

For the initial version, I used MNIST mainly to keep training lightweight and to make reproduction easy in CPU environments. Once the example is finalized, I can extend it to the Oxford Flowers dataset to make it more comprehensive.

@jafermarq (Member) left a comment

Hey @aash-mohammad, please find my first wave of feedback. I was able to run the example on GPU after adding a small tweak (ensuring the model and its inputs share the same dtype; I left a couple of comments about it below). It's worth double-checking the diffusers documentation in case there is a better way of doing this.
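For reference, a minimal sketch of the kind of dtype-alignment tweak mentioned above; the variable names and the tiny UNet are illustrative, not taken from the PR:

```python
import torch
from diffusers import UNet2DModel

# Illustrative model, moved to GPU in half precision (requires CUDA).
unet = UNet2DModel(sample_size=32, in_channels=1, out_channels=1).to("cuda", torch.float16)

noisy_images = torch.randn(4, 1, 32, 32, device="cuda")   # float32 by default
timesteps = torch.randint(0, 1000, (4,), device="cuda")

# Cast the inputs to the model's parameter dtype before the forward pass,
# so mat1/mat2 dtype mismatches cannot occur.
noisy_images = noisy_images.to(dtype=next(unet.parameters()).dtype)
noise_pred = unet(noisy_images, timesteps).sample
```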

I ran the example for 20 rounds, but I only see the validation loss decreasing marginally, and the train loss always outputs nan. Did you observe the same behaviour?

Finally, could you also include a README.md in the same directory as the pyproject.toml? Please follow the style of the other examples' READMEs, like the one in quickstart-pytorch. We can start with a short README (just like the other quickstart- examples) and expand it later. It would be amazing if, under the first paragraph describing what the example does, you could include a minimal diagram and the images produced by the simulation (as other people would see when running it on their own), as we did with our LeRobot example: https://github.com/adap/flower/tree/main/examples/quickstart-lerobot

@jafermarq (Member)

Hi @aash-mohammad, thanks for going through the review comments I left. Please ping me when you commit your changes; then I can do the final review 🙌

@aash-mohammad (Author) commented Nov 18, 2025

Hey @jafermarq,
Thanks for your valuable comments and feedback. I've addressed all the points you raised.

Regarding your question:

> I ran the example for 20 rounds but only see the validation loss decreasing marginally, and the train loss always outputs nan. Did you observe the same behavior?

Possible reason:

When a client trains the diffusion model on a larger number of samples, its local computation takes longer, and the node sometimes fails to respond within the expected time. If the response times out, those results are skipped, which can leave the training loss missing or nan.

To mitigate this, we can either:

- reduce the dataset size per client, or
- increase the client response timeout to allow sufficient time for training (see the sketch after this list).
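A minimal sketch of the timeout mitigation, assuming the example configures the server with Flower's `ServerConfig`; the timeout value is illustrative:

```python
from flwr.server import ServerConfig

# Raise the per-round timeout so slow clients are not dropped mid-round.
# 600 seconds is an arbitrary example value; tune it to the observed
# local training time per client.
config = ServerConfig(num_rounds=20, round_timeout=600.0)
```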

Updates:

- Replaced the MNIST dataset with the Oxford Flowers dataset for better alignment with image generation tasks.
- The README.md is already located in the same directory as the pyproject.toml.

Let me know if further adjustments are needed!

@aash-mohammad force-pushed the aash/diffusion_example branch 2 times, most recently from 692f9d5 to 898f874 on November 18, 2025, 18:00
@aash-mohammad (Author)

Hi @jafermarq,
Please let me know once you've completed the final review of my code. After that, I'll proceed with raising the PR that adds the input- and output-based privacy layer.

@jafermarq (Member) left a comment

Hi @aash-mohammad, sorry for the delay. I tried running the example by simply doing `flwr run . local-simulation-gpu`, but it crashes with `RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half`. Could you take a look and ping us when it is resolved?

Also, ideally the example should be complete in the sense that it shows how someone doing the fine-tuning can later generate an image with the resulting model. This part is not explained yet.

@aash-mohammad (Author)

Hi @jafermarq,
I've made the changes, and the example now runs with `flwr run . local-simulation-gpu`. I'm no longer encountering the error you saw earlier. I've also added a script that automatically generates images once the model fine-tuning is completed.
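For context, a minimal sketch of what such a generation script could look like, assuming the example fine-tunes a Stable Diffusion UNet with LoRA and saves the adapters locally; the model ID, prompt, and paths are illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the base pipeline, then attach the federated LoRA adapters on top.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("outputs/lora")  # hypothetical path to the saved adapters
pipe.to("cuda")

# Sample an image with the fine-tuned model and save it to disk.
image = pipe("a photo of an oxford flower", num_inference_steps=30).images[0]
image.save("generated_flower.png")
```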
