Add diffusion model training example using federated learning with Lo… #6071
Conversation
I have tested this code on my local system, and it runs successfully on the CPU using a small subset of the MNIST dataset. However, I was unable to test it on a GPU due to hardware limitations on my Mac system.
Hey @aash-mohammad, thanks for opening a PR for a new example🚀! We'll take a look in the coming days. What would be a good dataset to use other than MNIST? How about https://huggingface.co/datasets/nkirschi/oxford-flowers 😉?
Hi @jafermarq, for the initial version I used MNIST mainly to ensure lightweight training and easier reproducibility on CPU environments. Once the example is finalized, I can extend support to the Oxford Flowers dataset to make the example more comprehensive.
jafermarq left a comment
Hey @aash-mohammad, please find my first wave of feedback. I was able to run the example on GPU after adding a small tweak (ensuring both the model and its inputs have the same dtype; I left a couple of comments about it below). But please double-check the diffusers documentation in case there's a better way of doing this.
I ran the example for 20 rounds, but the validation loss only decreases marginally and the train loss always outputs nan. Did you observe the same behaviour?
Finally, could you also include a README.md in the same directory as the pyproject.toml? Please follow the style of the other examples' READMEs, like the one in quickstart-pytorch. We can start with a short README (just like in the other quickstart- examples) and expand it later. It would be great if, under the first paragraph describing what the example does, you could include a minimal diagram and the resulting images from the simulation (as other people would see them when running it on their own), as we did with our LeRobot example: https://github.com/adap/flower/tree/main/examples/quickstart-lerobot
Hi @aash-mohammad, thanks for going through the review comments I left. Please ping me when you commit your changes. Then I can do the final review 🙌
Hey @jafermarq, regarding your question, possible reasons:
Updates:
Let me know if further adjustments are needed!
Force-pushed from 692f9d5 to 898f874
Hi @jafermarq
jafermarq left a comment
Hi @aash-mohammad, sorry for the delay. I tried running the example simply doing `flwr run . local-simulation-gpu`, but it crashes with the error `RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half`. Could you take a look and ping us when it is resolved?
Also, ideally the example would be complete in the sense that it shows how a person doing the fine-tuning can later generate an image with the resulting model. This part is not explained.
Force-pushed from 898f874 to 91beced
Hi @jafermarq, with `flwr run . local-simulation-gpu` I'm no longer encountering the error you were seeing earlier. I've also added a script that automatically generates images once the model fine-tuning is completed.
Federated Diffusion Model Training Example
This example demonstrates how to train a Diffusion Model in a Federated Learning (FL) environment using Flower
and Hugging Face Diffusers.
The goal of this example is to show how privacy-preserving generative models can be collaboratively trained across distributed clients without sharing raw data. Each client trains locally on a partition of the MNIST dataset, and the global model is updated via federated averaging (FedAvg).
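The FedAvg aggregation described above can be sketched in a few lines: the server averages each client's parameters, weighted by the number of local training examples. This is a minimal sketch of the general FedAvg rule operating on lists of NumPy arrays (the format Flower's NumPy clients exchange); Flower's built-in `FedAvg` strategy handles this internally, so the helper name here is purely illustrative.

```python
import numpy as np

# Hypothetical helper: weighted average of per-client parameter lists.
# `client_params` is a list (one entry per client) of lists of NumPy
# arrays; `num_examples` gives each client's local dataset size.
def fedavg(client_params, num_examples):
    total = sum(num_examples)
    return [
        # Average each layer across clients, weighted by dataset size.
        sum(w * n for w, n in zip(layer_updates, num_examples)) / total
        for layer_updates in zip(*client_params)
    ]
```

For example, averaging two single-layer clients with weights 1 and 3 yields a result three quarters of the way toward the second client's parameters.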
To optimize training efficiency, the implementation leverages LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning, drastically reducing memory and computation requirements. This setup enables diffusion model training even on systems with limited GPU resources.
#6045