Commit 9ca9ef2

Merge branch 'main' into get-submodels
2 parents: 0b6abe1 + bdee088

96 files changed: +2860 −1228 lines

docs/source/en/deepspeed.md

Lines changed: 102 additions & 0 deletions
@@ -368,6 +368,108 @@ The example ZeRO-3 and ZeRO-Infinity config below sets most of the parameter val
}
```

### Sequence Parallelism

DeepSpeed's ALST/Ulysses sequence parallelism enables training with very long sequences by splitting each sequence across multiple GPUs. This is particularly useful for training large language models with extremely long context lengths.

Arctic Long Sequence Training (ALST) uses a combination of sharding inputs along the sequence dimension and attention head parallelism. With this approach, you can train Llama-8B with sequence lengths of up to 500K tokens on a single H100 GPU, 3.7M tokens on a single node, or 15M tokens on just four nodes. The implementation described here enables one component of the full ALST system. For additional optimizations like TiledMLP and activation checkpoint offloading, refer to the [DeepSpeed ALST tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).

> [!TIP]
> For more detailed information about sequence parallelism, see the Accelerate [Sequence Parallelism](https://huggingface.co/docs/accelerate/concept_guides/sequence_parallelism) guide.

To enable ALST/Ulysses sequence parallelism with [`Trainer`], configure `parallelism_config` in [`TrainingArguments`]. Sequence parallelism is configured through Accelerate's `ParallelismConfig` and requires an Accelerate version newer than 1.12.0.

```py
from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig

# Example: 4 GPUs with sp_size=4 and dp_replicate_size=1 (no data parallelism)
# Ensure total_size = dp_replicate_size * dp_shard_size * sp_size = 1 * 1 * 4 = 4 GPUs
parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=4,  # number of GPUs to split the sequence across
    dp_replicate_size=1,  # explicit: no data parallelism
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length_is_variable=True,
        sp_attn_implementation="sdpa",
    ),
)

training_args = TrainingArguments(
    ...,
    deepspeed="path/to/deepspeed_config.json",
    parallelism_config=parallelism_config,
)
```
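
Because the feature is gated on the Accelerate version, a guard at startup can fail fast instead of erroring mid-setup. This is a minimal sketch, not part of the documented API; it only assumes `accelerate.__version__` and the `packaging` helper, both available in standard installs.

```py
import accelerate
from packaging import version

# ALST/Ulysses SP via ParallelismConfig needs accelerate > 1.12.0 (see above)
if version.parse(accelerate.__version__) <= version.parse("1.12.0"):
    raise RuntimeError(
        f"accelerate {accelerate.__version__} is too old for DeepSpeed sequence "
        "parallelism; install a version newer than 1.12.0"
    )
```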
You can also configure sequence parallelism using an Accelerate config file.

```yaml
distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_config_file: path/to/ds_config.json
machine_rank: 0
num_machines: 1
num_processes: 4  # total number of processes
parallelism_config:
  parallelism_config_sp_size: 4  # sequence parallel size
  parallelism_config_dp_replicate_size: 1  # dp_replicate_size * dp_shard_size * sp_size must equal num_processes
  parallelism_config_sp_backend: deepspeed
  parallelism_config_sp_seq_length_is_variable: true
  parallelism_config_sp_attn_implementation: sdpa
```

Important configuration parameters include the following.

* `sp_backend` must be set to `"deepspeed"` to use ALST/Ulysses sequence parallelism.
* `sp_size` is the degree of sequence parallelism. For example, `sp_size=4` means 4 GPUs process a single sequence in parallel. At least 2 GPUs are required to enable sequence parallelism. Each rank still receives a unique data stream from the DataLoader (as in data parallelism), but the effective data-parallel size becomes `dp_world_size = world_size / sp_size`. With 4 GPUs and `sp_size=4`, each of the 4 ranks gets different samples, yet `dp_world_size=1` for total batch size calculations (see the sketch after this list).
* `sp_seq_length_is_variable` determines how sequence lengths are handled. When set to `True` (recommended), the implementation adapts to varying sequence lengths between batches. When `False`, all sequences must be padded to a fixed length specified by `sp_seq_length`.
* `sp_attn_implementation` specifies the attention implementation to use. Supported values are `"sdpa"`, `"flash_attention_2"`, and `"flash_attention_3"`. Flash Attention is recommended for best performance, especially with multiple samples in a batch, because SDPA may incorrectly attend across sample boundaries.
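
The batch size rule above is easiest to see as plain arithmetic. The sketch below uses hypothetical values (`world_size=8`, `sp_size=4`) and is illustration only, not `Trainer` internals.

```py
# Per the rule above, the effective data-parallel size for batch accounting
# is world_size / sp_size: GPUs in the same SP group cooperate on one sequence.
world_size = 8                          # total GPUs
sp_size = 4                             # GPUs cooperating on each sequence
dp_world_size = world_size // sp_size   # -> 2 effective data-parallel ranks

per_device_train_batch_size = 1
gradient_accumulation_steps = 1
global_batch_size = per_device_train_batch_size * dp_world_size * gradient_accumulation_steps
print(global_batch_size)  # -> 2 samples per optimizer step
```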

> [!WARNING]
> Sequence parallelism requires your model to use one of the supported attention implementations (`sdpa`, `flash_attention_2`, or `flash_attention_3`). The `eager` attention implementation is not supported because it doesn't properly handle `position_ids`.

When using sequence parallelism, make sure your sequences are properly padded. Use `pad_to_multiple_of` in your data collator so sequence lengths are divisible by `sp_size`. For example, with `sp_size=4`, set `pad_to_multiple_of=4` or higher.

```py
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=4,  # ensure sequence lengths are divisible by sp_size
)
```
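
As a quick sanity check, you can collate a couple of samples and confirm the padded length is divisible by `sp_size`. This is a throwaway sketch reusing the `tokenizer` and `data_collator` from the snippet above, with made-up input strings.

```py
# Collate two samples of different lengths; the collator pads to a multiple of 4
features = [tokenizer("a short sample"), tokenizer("a somewhat longer sample for padding")]
batch = data_collator(features)

seq_len = batch["input_ids"].shape[1]
assert seq_len % 4 == 0, f"sequence length {seq_len} is not divisible by sp_size=4"
```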

When using `sp_size` with multiple GPUs, you **must** explicitly set `dp_replicate_size` or `dp_shard_size` so that `total_size = dp_replicate_size * dp_shard_size * sp_size` equals your total number of GPUs. For example, with 8 GPUs and `sp_size=4`, you must set `dp_replicate_size=2` (since 2 × 1 × 4 = 8):

```py
parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=4,
    dp_replicate_size=2,
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length_is_variable=True,
        sp_attn_implementation="flash_attention_2",
    ),
)
```

[`Trainer`] automatically handles the special requirements of sequence parallelism, including:

* Adapting the data loader via DeepSpeed's [`UlyssesSPDataLoaderAdapter`](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py) to shard sequences across GPUs. **Important**: unlike tensor parallelism (where all ranks must receive identical data), each rank with SP receives a unique data stream from the DataLoader (similar to DP). The adapter distributes sequence chunks across SP ranks internally, so your DataLoader should continue feeding different samples to each rank.
* Generating `position_ids` when not provided
* Creating `shift_labels` for causal language modeling (illustrated below)
* Aggregating loss across sequence parallel ranks with proper masking for `-100` labels
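
For intuition on the `shift_labels` entry in this list: causal language modeling conventionally shifts the labels left by one position and pads the final slot with `-100` so the loss ignores it. A minimal sketch of that convention (not the `Trainer` internals):

```py
import torch

labels = torch.tensor([[10, 11, 12, 13]])

# Each position predicts the *next* token; -100 marks positions the loss skips
shift_labels = torch.full_like(labels, -100)
shift_labels[:, :-1] = labels[:, 1:]
print(shift_labels)  # tensor([[ 11,  12,  13, -100]])
```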

You can launch training with sequence parallelism using the `accelerate launch` command.

```bash
accelerate launch --config_file alst_config.yaml your_training_script.py \
    --output_dir output_dir \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1
```

## Training features

DeepSpeed supports many training features that can be configured in the config file. This section describes some of the most important features.

docs/source/en/model_doc/pix2struct.md

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,11 @@ The original code can be found [here](https://github.com/google-research/pix2str
[[autodoc]] Pix2StructImageProcessor
    - preprocess

+## Pix2StructImageProcessorFast
+
+[[autodoc]] Pix2StructImageProcessorFast
+    - preprocess
+
## Pix2StructTextModel

[[autodoc]] Pix2StructTextModel

docs/source/en/model_doc/sam3_video.md

Lines changed: 33 additions & 0 deletions
@@ -97,6 +97,39 @@ Processed 51 frames
>>> print(f"Masks shape: {frame_0_outputs['masks'].shape}")
```

You can also track multiple object categories simultaneously by providing multiple prompts. The model efficiently reuses vision features across all prompts:

```python
>>> # Add multiple text prompts (or use a list in add_text_prompt)
>>> multi_prompt_session = processor.init_video_session(
...     video=video_frames,
...     inference_device=device,
...     processing_device="cpu",
...     video_storage_device="cpu",
...     dtype=torch.bfloat16,
... )
>>>
>>> prompts = ["person", "bed", "lamp"]
>>> processor.add_text_prompt(multi_prompt_session, prompts)
>>>
>>> # Process video - detects objects from ALL prompts in a single pass
>>> multi_outputs_per_frame = {}
>>> for model_outputs in model.propagate_in_video_iterator(
...     inference_session=multi_prompt_session, max_frame_num_to_track=50
... ):
...     processed_outputs = processor.postprocess_outputs(multi_prompt_session, model_outputs)
...     multi_outputs_per_frame[model_outputs.frame_idx] = processed_outputs
>>>
>>> # Check which objects were detected by each prompt
>>> frame_0_outputs = multi_outputs_per_frame[0]
>>> prompt_to_obj_ids = frame_0_outputs["prompt_to_obj_ids"]
>>> for prompt, obj_ids in prompt_to_obj_ids.items():
...     print(f"{prompt}: {len(obj_ids)} objects")
person: 2 objects
bed: 1 objects
lamp: 1 objects
```

#### Streaming Video Inference

<div class="warning">

src/transformers/cli/add_fast_image_processor.py

Lines changed: 0 additions & 2 deletions
@@ -59,8 +59,6 @@ def add_fast_image_processor(
    image_processor_name = re.findall(r"class (\w*ImageProcessor)", content_base_file)
    if not image_processor_name:
        raise ValueError(f"No ImageProcessor class found in {image_processing_module_file}")
-    elif len(image_processor_name) > 1:
-        raise ValueError(f"Multiple ImageProcessor classes found in {image_processing_module_file}")

    image_processor_name = image_processor_name[0]
    fast_image_processor_name = image_processor_name + "Fast"

src/transformers/conversion_mapping.py

Lines changed: 8 additions & 13 deletions
@@ -15,7 +15,7 @@

from copy import deepcopy

-from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
+from .core_model_loading import Concatenate, MergeModulelist, WeightConverter, WeightRenaming
from .utils import is_torch_available

@@ -26,6 +26,7 @@
def _build_checkpoint_conversion_mapping():
    mapping = {
        "mixtral": [
+            WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
            WeightConverter(
                source_keys=[
                    "block_sparse_moe.experts.*.w1.weight",
@@ -50,12 +51,6 @@ def _build_checkpoint_conversion_mapping():
                ),  # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
            ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
        ),
-        # WeightConverter(
-        #     ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
-        #     "self_attn.qkv_proj",
-        #     operations=[Concatenate(dim=0)],  # more like stack?
-        # ),
-        WeightConverter("*.block_sparse_moe.", "*.mlp."),
        ],
        "qwen2_moe": [
            WeightConverter(
@@ -73,34 +68,34 @@ def _build_checkpoint_conversion_mapping():
            ),
        ],
        "legacy": [
-            WeightConverter(
+            WeightRenaming(
                source_keys="LayerNorm.gamma",
                target_keys="LayerNorm.weight",
            ),
-            WeightConverter(
+            WeightRenaming(
                source_keys="LayerNorm.beta",
                target_keys="LayerNorm.bias",
            ),
        ],
    }
    if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
        mapping["legacy"] += [
-            WeightConverter(
+            WeightRenaming(
                source_keys="weight_g",
                target_keys="parametrizations.weight.original0",
            ),
-            WeightConverter(
+            WeightRenaming(
                source_keys="weight_v",
                target_keys="parametrizations.weight.original1",
            ),
        ]
    else:
        mapping["legacy"] += [
-            WeightConverter(
+            WeightRenaming(
                source_keys="parametrizations.weight.original0",
                target_keys="weight_g",
            ),
-            WeightConverter(
+            WeightRenaming(
                source_keys="parametrizations.weight.original1",
                target_keys="weight_v",
            ),
