Implement dpo tuning #417
Conversation
- Added DPO loss function in `mlx_lm/tuner/losses.py` (see the sketch after this list)
- Introduced `PreferenceDataset` in `mlx_lm/tuner/datasets.py`
- Extended trainer with DPO training step and dual forward passes
- Created `mlx_lm/dpo.py` with the main training logic (mirroring the LoRA structure)
- Added CLI entry point in `mlx_lm/dpo.py` and integrated the `dpo` subcommand in `mlx_lm/__main__.py`
- Implemented unit and integration tests in `tests/test_dpo.py` and `tests/test_losses.py`
- Updated documentation with data formats, configs, metrics, and guidelines
- Ensured compliance with MLX-LM contributing standards (formatting, testing, pre-commit)
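For readers skimming the diff, here is a minimal sketch of what a Bradley-Terry style DPO loss can look like in MLX. The function name and signature are illustrative, not necessarily the PR's actual code:

```python
import mlx.core as mx

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Illustrative DPO loss over per-sequence log-probabilities."""
    # Implicit rewards are beta-scaled log-ratios of policy vs. reference
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    margins = chosen_rewards - rejected_rewards
    # -log(sigmoid(m)) == logaddexp(0, -m), a numerically stable form
    loss = mx.mean(mx.logaddexp(0.0, -margins))
    return loss, chosen_rewards, rejected_rewards
```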
- Implement DPO loss function with Bradley-Terry preference model
- Add PreferenceDataset for handling chosen/rejected response pairs
- Create DPO training pipeline with dual forward passes (see the sketch after this list)
- Add CLI interface (`mlx_lm.dpo`) for loading a pretrained model
- Support LoRA/DoRA/full fine-tuning with DPO
- Include comprehensive test suite and documentation
- Compatible with existing MLX-LM infrastructure
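As a rough sketch of the dual-forward-pass idea (per-sequence log-probabilities computed once under the policy and once under the reference model), something along these lines is needed; the helper name and masking details are assumptions, not the PR's code:

```python
import mlx.core as mx

def sequence_logps(model, tokens, response_mask):
    """Sum of next-token log-probabilities over the response span only."""
    logits = model(tokens[:, :-1])                  # predict tokens 1..T
    targets = tokens[:, 1:]
    log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    token_logps = mx.take_along_axis(
        log_probs, mx.expand_dims(targets, axis=-1), axis=-1
    ).squeeze(-1)
    # Zero out prompt positions so only the chosen/rejected response contributes
    return (token_logps * response_mask[:, 1:]).sum(axis=-1)

# In a DPO step this runs four times: policy/reference x chosen/rejected,
# with the reference passes kept out of the gradient (frozen model or stop_gradient).
```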
- Add --reference-adapter-path parameter to DPO CLI
- Enable using locally finetuned LoRA adapters as reference models (see the sketch after this list)
- Leverages the existing load() function's adapter_path parameter
- Allows reusing MLX-finetuned models for subsequent DPO training

Usage: `mlx_lm.dpo --model base_model --reference-model base_model --reference-adapter-path path/to/adapters --data data.jsonl --train`
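The commit message points at mlx_lm's existing `load()` adapter support; roughly, the reference model can be constructed like this (model and adapter paths are placeholders):

```python
from mlx_lm import load

# Policy to be trained with DPO
policy_model, tokenizer = load("base_model")

# Frozen reference model, initialized from previously trained LoRA adapters
reference_model, _ = load("base_model", adapter_path="path/to/adapters")
```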
- Change --reference-adapter-path to --reference-model-adapters
- Resolves parameter-name confusion with adapter_path during config merge (illustrated below)
- Ensures reference model adapters are loaded correctly for DPO training
- Update CONFIG_DEFAULTS and function references accordingly

The original parameter name was too similar to adapter_path and caused the config loading logic to overwrite the reference adapter path with the current training adapter path.
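A hypothetical illustration of the failure mode described above; the real config-merge code may differ:

```python
# A path-resolution pass that treats every "*adapter_path" key as the training
# adapter directory would also clobber the reference model's adapters.
def resolve_adapter_paths(config, training_adapter_dir):
    for key in list(config):
        if key.endswith("adapter_path"):        # matched the old reference key too
            config[key] = training_adapter_dir  # reference path was lost here
    return config

# With the option renamed to "reference_model_adapters", the suffix check no
# longer matches it, so the reference adapters survive the config merge.
```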
…optimization

- Add preference accuracy and reward margin tracking to DPO evaluation (a minimal sketch follows this list)
- Implement memory-efficient shared weights when no reference model is specified
- Update validation and test reporting to include accuracy metrics
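A minimal sketch of how such metrics can fall out of the same reward terms the loss already computes; the names are assumptions:

```python
import mlx.core as mx

def preference_metrics(chosen_rewards, rejected_rewards):
    # Fraction of pairs where the chosen response gets the higher implicit reward
    accuracy = mx.mean((chosen_rewards > rejected_rewards).astype(mx.float32))
    # Average gap between chosen and rejected implicit rewards
    reward_margin = mx.mean(chosen_rewards - rejected_rewards)
    return accuracy, reward_margin
```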
…compiled step function
Hi maintainers! @awni @Goekdeniz-Guelmez This PR adds DPO support for preference-based fine-tuning. It's been open for ~2 months and I'd appreciate any feedback. Happy to address any concerns or make changes. Thanks!
Hey Mohammad, merging will take time since this is a big addition; you'll have to wait until it passes all the reviews and has the needed code quality.
Dear MLX team, some users (including me) have been awaiting the adoption of DPO into the MLX framework. We would highly appreciate it if you would consider merging this PR.
Add Direct Preference Optimization (DPO) Support
Summary
This PR implements Direct Preference Optimization (DPO) for MLX-LM, enabling users to fine-tune language models using human preference data without requiring a separate reward model.
What is DPO?
DPO is a simpler alternative to RLHF that directly optimizes on preference pairs (chosen vs. rejected responses), avoiding the complexity of training a separate reward model and running PPO. Under the Bradley-Terry preference model it optimizes the same objective as RLHF, and in practice it is more stable and efficient to train.
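For reference, the DPO objective from Rafailov et al. (2023) is:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) =
-\,\mathbb{E}_{(x,\, y_w,\, y_l)\sim\mathcal{D}}
\left[\log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
\right)\right]
$$

where $y_w$ and $y_l$ are the chosen and rejected responses, $\sigma$ is the logistic function, and $\beta$ controls how far the policy may drift from the reference model.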
Key Features Added
Core Implementation
Integration
- `python -m mlx_lm dpo` with full argument parsing

Data Format Support
{"prompt": "What is AI?", "chosen": "Detailed explanation...", "rejected": "Short answer."} {"messages": [{"role": "user", "content": "Hello"}], "chosen": "Detailed response", "rejected": "Hi."}Documentation & Testing
Documentation & Testing

- `mlx_lm/DPO.md` with usage examples and best practices

Usage Example
```shell
# Basic DPO training
mlx_lm.dpo \
  --model mlx-community/Meta-Llama-3-8B-Instruct-4bit \
  --train \
  --data preference_data/ \
  --beta 0.1 \
  --fine-tune-type lora
```

Files Added/Modified
- `mlx_lm/dpo.py` - Main DPO module
- `mlx_lm/tuner/losses.py` - DPO loss implementation
- `mlx_lm/tuner/datasets.py` - PreferenceDataset class
- `mlx_lm/tuner/trainer.py` - DPO training functions
- `mlx_lm/__main__.py` - CLI registration
- `mlx_lm/DPO.md` - Documentation
- `tests/test_dpo.py` - DPO-specific tests
- `tests/test_losses.py` - Enhanced with DPO loss tests

Testing
Benefits for MLX-LM Users
This implementation enables MLX-LM users to easily train more helpful, harmless, and honest models using preference data.