[WIP]Merge ulysses and ring-attention #5522
base: main
Conversation
# Conflicts: # swift/trainers/trainers.py
Summary of Changes
Hello @tastelikefeet, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant refactoring and enhancement of the sequence parallelism framework. It unifies the previously separate Ulysses and Ring Attention implementations into a single, more flexible `SequenceParallel` class. The changes enable advanced distributed training strategies, including a new 'ring parallel' dimension and the integration of 'zigzag ring flash attention'. This consolidation simplifies the API for trainers and improves the overall efficiency and scalability of training large language models by optimizing how data and computations are distributed across devices.
Highlights
- Unified Sequence Parallelism: The separate `Ulysses` and `RingAttention` sequence parallel implementations have been unified into a single, more robust `SequenceParallel` class within `ulysses.py`. This streamlines the codebase and provides a single entry point for sequence parallelism configurations.
- Introduction of Ring Parallelism (RP): A new dimension of parallelism, 'Ring Parallelism' (RP), has been introduced, allowing for more complex and efficient distribution strategies, particularly for very long sequences. This is reflected in the updated device mesh creation and data handling methods.
- Advanced Attention Mechanism Integration: Specific implementations of 'Zigzag Ring Flash Attention' for both standard and variable-length sequences have been integrated, suggesting an advanced approach to attention computation in a distributed setting.
- Refactored Sequence Parallel API: The API for interacting with sequence parallelism has been simplified. Instead of multiple `init_sequence_parallel` and `prepare_model` calls, a single `sequence_parallel.prepare` method is now used, making setup more consistent across different training components.
- Comprehensive Trainer Adaptations: Various trainers (RLHF, SFT, DPO, GRPO) and core trainer mixins have been updated to seamlessly integrate with the new unified `SequenceParallel` class. This includes adaptations for data loading, loss computation, input preparation, and gradient accumulation, ensuring compatibility and leveraging the new parallelism features.
- Codebase Simplification: Redundant files like `swift/trainers/sequence_parallel/base.py` and `swift/trainers/sequence_parallel/ring_attention.py` have been removed, reducing code duplication and simplifying the project structure.
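To make the new 'ring parallel' dimension concrete, here is a minimal sketch of how a rank could map onto a (dp, ulysses, ring) device mesh when the sequence-parallel group is split into an Ulysses dimension and a ring dimension. The function name and the layout are illustrative assumptions, not the actual ms-swift API:

```python
# Hypothetical sketch: map a global rank to (dp, ulysses, ring) coordinates.
# The names and ordering are illustrative, not taken from the PR's code.

def mesh_coords(rank: int, world_size: int, ulysses_size: int, ring_size: int):
    """The sequence-parallel group size is ulysses_size * ring_size;
    the remaining ranks form the data-parallel dimension."""
    sp_size = ulysses_size * ring_size
    assert world_size % sp_size == 0, 'world size must be divisible by sp size'
    dp = rank // sp_size            # which sequence-parallel group this rank is in
    sp_rank = rank % sp_size        # position inside the sequence-parallel group
    ulysses = sp_rank // ring_size  # Ulysses (head-scatter) coordinate
    ring = sp_rank % ring_size      # ring-attention coordinate
    return dp, ulysses, ring
```

For example, with 8 GPUs, `ulysses_size=2`, and `ring_size=2`, ranks 0-3 form one sequence-parallel group and ranks 4-7 another, each internally split 2x2 between the Ulysses and ring dimensions.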
Code Review
This pull request introduces a significant refactoring by merging the `ulysses` and `ring-attention` sequence parallel implementations into a unified `SequenceParallel` class. This is a great improvement for maintainability. The changes also move sequence parallel-specific logic from a shared utility file into the respective trainer classes, which improves modularity.
I've found a couple of issues that need to be addressed:
- A critical bug in `swift/trainers/mixin.py` where `get_sp_dataloader` is called with incorrect arguments.
- A code duplication issue in `swift/trainers/utils.py` where a function is defined twice.
Please see my detailed comments below. Once these issues are resolved, this PR will be in much better shape.
```diff
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError('Trainer: evaluation requires an eval_dataset.')
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-        dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size)
+        dataloader = self.get_sp_dataloader(self, eval_dataset, self.args.eval_batch_size)
```
The call to `get_sp_dataloader` has incorrect arguments. The method's signature is `get_sp_dataloader(self, dataset, batch_size, skip_batches=0)`, but it is being called with `self` as the first explicit argument, which is then bound to `dataset`. This will lead to a runtime error. The correct call should not pass `self` explicitly.
```diff
-        dataloader = self.get_sp_dataloader(self, eval_dataset, self.args.eval_batch_size)
+        dataloader = self.get_sp_dataloader(eval_dataset, self.args.eval_batch_size)
```
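The argument-shifting pitfall behind this bug can be reproduced in isolation with a stub class (the class and data below are illustrative, not the actual trainer):

```python
# Minimal reproduction of the bug: calling a bound method with an explicit
# `self` shifts every argument one position to the right, so the trainer
# instance lands in `dataset` and the real dataset lands in `batch_size`.

class Trainer:
    def get_sp_dataloader(self, dataset, batch_size, skip_batches=0):
        # Return the received arguments so the shift is observable.
        return dataset, batch_size

t = Trainer()
data = ['a', 'b', 'c']

buggy = t.get_sp_dataloader(t, data, 8)  # dataset == t, batch_size == data
fixed = t.get_sp_dataloader(data, 8)     # dataset == data, batch_size == 8
```

With the buggy call, every downstream operation that expects a dataset receives the `Trainer` instance instead, which is why this fails at runtime rather than at call time.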
```python
import torch
import torch.nn.functional as F


def per_token_loss_func(outputs, labels, enable_dft_loss: bool = False, **kwargs):
    """Calculate the per-token cross-entropy loss.

    Args:
        outputs: The model outputs containing `logits`.
        labels: The labels.
        enable_dft_loss: Enable DFT loss.

    Returns:
        A tensor of per-token cross-entropy losses.
    """
    logits = outputs.logits
    # Upcast to float when computing the loss to avoid potential precision issues
    logits = logits.float()
    # Shift labels left by one so each logit predicts the next token
    labels = torch.roll(labels, shifts=-1, dims=-1).view(-1)

    # Flatten the tokens
    logits = logits.view(-1, logits.shape[-1])
    # Enable model parallelism
    labels = labels.to(logits.device)
    loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='none')
    if enable_dft_loss:
        with torch.no_grad():
            target_probs = torch.exp(-loss)
        loss *= target_probs
    return loss
```
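As an aside (not part of the PR's code), the DFT weighting in the branch above can be illustrated without torch: since the per-token loss is `-log p(target)`, `exp(-loss)` recovers `p(target)`, so the multiplication down-weights tokens the model already predicts confidently. The helper below is a hypothetical plain-Python sketch that takes pre-softmaxed probability distributions rather than logits:

```python
import math

def per_token_dft_loss(probs, targets):
    """Per-token cross entropy with DFT weighting, as in the snippet above:
    loss_i = -log p_i(target_i), then scaled by exp(-loss_i) == p_i(target_i).
    `probs` is a list of per-token probability distributions (illustrative,
    already softmaxed, unlike the raw logits the real function receives)."""
    losses = []
    for dist, t in zip(probs, targets):
        ce = -math.log(dist[t])            # standard cross entropy
        losses.append(ce * math.exp(-ce))  # DFT: weight by target probability
    return losses
```

For a uniform two-way distribution and target index 0, the cross entropy is `log 2` and the DFT-weighted loss is `0.5 * log 2`, showing how low-confidence tokens are attenuated less in absolute terms than they would be amplified under plain CE.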
PR type
PR information
Write the detailed information belonging to this PR.
Experiment results
Paste your experiment results here (if needed).