
Commit e833ed5

Merge branch 'main' into fix-pipeline-local-files
2 parents 228bd90 + bdee088 commit e833ed5


43 files changed: +770 -226 lines changed

docs/source/en/deepspeed.md

Lines changed: 102 additions & 0 deletions
@@ -368,6 +368,108 @@ The example ZeRO-3 and ZeRO-Infinity config below sets most of the parameter val
}
```

### Sequence Parallelism

DeepSpeed's ALST/Ulysses sequence parallelism enables training with very long sequences by splitting each sequence across multiple GPUs. This is particularly useful for training large language models on very long contexts.

Arctic Long Sequence Training (ALST) combines sharding inputs along the sequence dimension with attention head parallelism. With this approach, you can train Llama-8B with sequence lengths of up to 500K tokens on a single H100 GPU, 3.7M tokens on a single node, or 15M tokens on just four nodes. The implementation described here enables one component of the full ALST system. For additional optimizations like TiledMLP and activation checkpoint offloading, refer to the [DeepSpeed ALST tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).

> [!TIP]
> For more detailed information about sequence parallelism, see the Accelerate [Sequence Parallelism](https://huggingface.co/docs/accelerate/concept_guides/sequence_parallelism) guide.

To enable ALST/Ulysses sequence parallelism with [`Trainer`], configure `parallelism_config` in [`TrainingArguments`]. Sequence parallelism is configured via Accelerate's `ParallelismConfig` and requires an Accelerate version higher than 1.12.0.

```py
from accelerate.utils import ParallelismConfig, DeepSpeedSequenceParallelConfig

# Example: 4 GPUs with sp_size=4, dp_replicate_size=1 (no data parallelism)
# Ensure total_size = dp_replicate_size * dp_shard_size * sp_size = 1 * 1 * 4 = 4 GPUs
parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=4,  # Number of GPUs to split the sequence across
    dp_replicate_size=1,  # Explicit: no data parallelism
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length_is_variable=True,
        sp_attn_implementation="sdpa",
    ),
)

training_args = TrainingArguments(
    ...,
    deepspeed="path/to/deepspeed_config.json",
    parallelism_config=parallelism_config,
)
```

You can also configure sequence parallelism using an Accelerate config file.

```yaml
distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_config_file: path/to/ds_config.json
machine_rank: 0
num_machines: 1
num_processes: 4  # Total number of processes
parallelism_config:
  parallelism_config_sp_size: 4  # Sequence parallel size
  parallelism_config_dp_replicate_size: 1  # Must satisfy: dp_replicate_size * dp_shard_size * sp_size = num_processes
  parallelism_config_sp_backend: deepspeed
  parallelism_config_sp_seq_length_is_variable: true
  parallelism_config_sp_attn_implementation: sdpa
```

Important configuration parameters include the following.

* `sp_backend` must be set to `"deepspeed"` to use ALST/Ulysses sequence parallelism.
* `sp_size` is the degree of sequence parallelism. For example, `sp_size=4` means 4 GPUs process a single sequence in parallel. You need at least 2 GPUs to enable sequence parallelism. **Data feeding**: each rank receives a unique data stream from the DataLoader (as in DP). **Batch size calculation**: the effective `dp_world_size = world_size / sp_size`, so with 4 GPUs and `sp_size=4`, each of the 4 ranks gets different samples from the DataLoader, but `dp_world_size=1` is used for total batch size calculations (see the sketch after this list).
* `sp_seq_length_is_variable` determines how sequence lengths are handled. When set to `True` (recommended), the implementation adapts to varying sequence lengths between batches. When `False`, all sequences must be padded to a fixed length specified by `sp_seq_length`.
* `sp_attn_implementation` specifies the attention implementation to use. Supported values are `"sdpa"`, `"flash_attention_2"`, or `"flash_attention_3"`. Flash Attention is recommended for best performance, especially with multiple samples in a batch, because SDPA may incorrectly attend across sample boundaries.

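To make the batch size arithmetic concrete, here is a minimal sketch of the calculation described above. The variables are plain Python names used only for illustration, not `Trainer` attributes.

```py
# Batch size arithmetic under ALST/Ulysses sequence parallelism (illustration only).
world_size = 4                    # total number of GPUs
sp_size = 4                       # sequence parallel degree
per_device_train_batch_size = 1
gradient_accumulation_steps = 1

# Ranks in the same SP group cooperate on the same sequences, so only
# world_size / sp_size ranks count toward the data-parallel batch size.
dp_world_size = world_size // sp_size  # 4 // 4 = 1

global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * dp_world_size
print(dp_world_size, global_batch_size)  # 1 1
```

With 8 GPUs and `sp_size=4`, the same arithmetic gives `dp_world_size=2`, which matches the 8-GPU example further below.
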
> [!WARNING]
> Sequence parallelism requires your model to use one of the supported attention implementations (`sdpa`, `flash_attention_2`, or `flash_attention_3`). The `eager` attention implementation is not supported because it doesn't properly handle `position_ids`.

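One way to satisfy this requirement is to request a supported implementation explicitly when loading the model. This is a minimal sketch; the checkpoint name is a placeholder, and the value should match what you configured in `sp_attn_implementation`.

```py
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; pick the implementation you configured in `sp_attn_implementation`.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="sdpa",
)
```
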
When using sequence parallelism, make sure your sequences are properly padded. Use `pad_to_multiple_of` in your data collator so that sequence lengths are divisible by `sp_size`. For example, with `sp_size=4`, set `pad_to_multiple_of=4` or higher.

```py
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=4,  # Ensure sequences are divisible by sp_size
)
```

When using `sp_size` with multiple GPUs, you **must** explicitly set `dp_replicate_size` or `dp_shard_size` so that `total_size = dp_replicate_size * dp_shard_size * sp_size` equals your total number of GPUs. For example, with 8 GPUs and `sp_size=4`, set `dp_replicate_size=2` (since 2 × 1 × 4 = 8):

```py
parallelism_config = ParallelismConfig(
    sp_backend="deepspeed",
    sp_size=4,
    dp_replicate_size=2,
    sp_handler=DeepSpeedSequenceParallelConfig(
        sp_seq_length_is_variable=True,
        sp_attn_implementation="flash_attention_2",
    ),
)
```

[`Trainer`] automatically handles the special requirements for sequence parallelism, including:

* Adapting the data loader via DeepSpeed's [`UlyssesSPDataLoaderAdapter`](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py) to shard sequences across GPUs. **Important**: unlike tensor parallelism (where all ranks must receive identical data), each rank with SP receives a unique data stream from the DataLoader (similar to DP). The adapter distributes sequence chunks across SP ranks internally, so your DataLoader should continue feeding different samples to each rank.
* Generating `position_ids` when not provided
* Creating `shift_labels` for causal language modeling (see the sketch after this list)
* Aggregating loss across sequence parallel ranks with proper masking for `-100` labels

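As a rough illustration of the `shift_labels` convention for causal language modeling (a toy sketch, not the actual [`Trainer`] internals):

```py
import torch

# Toy labels for one causal LM sample; -100 marks positions ignored by the loss.
labels = torch.tensor([[15, 27, 42, 8, -100, -100]])

# Each position is supervised by the *next* token, so the targets shift one step
# to the left and the final position (which has no next token) becomes -100.
shift_labels = torch.full_like(labels, -100)
shift_labels[:, :-1] = labels[:, 1:]

print(shift_labels)  # tensor([[  27,   42,    8, -100, -100, -100]])
```
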
You can launch training with sequence parallelism using the `accelerate launch` command.

```bash
accelerate launch --config_file alst_config.yaml your_training_script.py \
    --output_dir output_dir \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1
```

## Training features

DeepSpeed supports many training features that can be configured in the config file. This section describes some of the most important features.

src/transformers/core_model_loading.py

Lines changed: 5 additions & 2 deletions
@@ -359,7 +359,10 @@ def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Op
         return collected_tensors, misc


-GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2)  # NVMe: 8-16; HDD/NFS: 2-4
+# For I/O bound operations (i.e. here reading files), it is better to have fewer threads, e.g. 4 is a good default.
+# Having too many is actually harming performances quite a lot, i.e. using 16 can sometimes lead to taking TWICE
+# as much time to load the same model
+GLOBAL_WORKERS = min(4, os.cpu_count() or 4)


 def _materialize_copy(tensor, device=None, dtype=None):
@@ -610,7 +613,7 @@ def convert_and_load_state_dict_in_model(
     tp_plan = tp_plan or {}
     device_map = device_map or {"": "cpu"}
     device_map_regex = re.compile(
-        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
+        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
     )
     dtype_plan = dtype_plan or {}
     weight_mapping = weight_mapping or []
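
For context on the sort-key change above: the `device_map` keys are joined into a regex alternation, and Python's alternation matches left to right, so among keys at the same depth the longer (more specific) name must come first to win the match. A minimal sketch with a toy `device_map` (not the library's actual loading path):

```py
import re

# Toy device_map: "model.layers.1" and "model.layers.10" have the same depth (dot count),
# so the length tie-breaker is what puts the more specific key first in the alternation.
device_map = {"": "cpu", "model.layers.1": 0, "model.layers.10": 1}

device_map_regex = re.compile(
    "|".join(
        rf"({k})"
        for k in sorted(device_map, key=lambda x: (x.count("."), len(x)), reverse=True)
    )
)

# Without the length tie-breaker, "model.layers.1" could precede "model.layers.10"
# and match only a prefix of the parameter name below.
print(device_map_regex.match("model.layers.10.self_attn.q_proj.weight").group())  # model.layers.10
```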

src/transformers/masking_utils.py

Lines changed: 10 additions & 8 deletions
@@ -340,9 +340,6 @@ def sdpa_mask(
         allow_is_causal_skip (`bool`, optional):
             Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
             `torch.sdpa` instead. Default to `True`.
-        allow_torch_fix (`bool`, optional):
-            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
-            versions. We need an arg to skip it when using eager. By default `True`.
         allow_is_bidirectional_skip (`bool`, optional):
             Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
             i.e. full attention without any padding. Default to `False`.
@@ -480,6 +477,7 @@ def eager_mask(
     mask_function: Callable = causal_mask_function,
     attention_mask: Optional[torch.Tensor] = None,
     dtype: torch.dtype = torch.float32,
+    allow_is_bidirectional_skip: bool = False,
     use_vmap: bool = False,
     **kwargs,
 ) -> torch.Tensor:
@@ -503,13 +501,15 @@ def eager_mask(
             The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
         dtype (`torch.dtype`, optional):
             The dtype to use for the mask. By default, `torch.float32`.
+        allow_is_bidirectional_skip (`bool`, optional):
+            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
+            i.e. full attention without any padding. Default to `False`.
         use_vmap (`bool`, optional):
             Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
             index-based (for the cost of speed performance). By default `False`.
     """
     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
     _ = kwargs.pop("allow_is_causal_skip", None)
-    _ = kwargs.pop("allow_is_bidirectional_skip", None)
     _ = kwargs.pop("allow_torch_fix", None)
     mask = sdpa_mask(
         batch_size=batch_size,
@@ -519,14 +519,16 @@ def eager_mask(
         mask_function=mask_function,
         attention_mask=attention_mask,
         allow_is_causal_skip=False,
-        allow_is_bidirectional_skip=False,
+        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
         allow_torch_fix=False,
         use_vmap=use_vmap,
         **kwargs,
     )
-    min_dtype = torch.finfo(dtype).min
-    # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
-    mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+    # only bidirectional masks can be skipped, otherwise we convert bool -> float
+    if mask is not None:
+        min_dtype = torch.finfo(dtype).min
+        # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
+        mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
     return mask

src/transformers/modeling_flash_attention_utils.py

Lines changed: 8 additions & 2 deletions
@@ -24,6 +24,7 @@
     is_flash_attn_3_available,
     is_flash_attn_greater_or_equal_2_10,
     is_torch_npu_available,
+    is_torch_xpu_available,
     logging,
 )

@@ -45,7 +46,12 @@ def flash_attn_supports_top_left_mask():

 # TODO Deprecate when all models have the attention interface
 def is_flash_attn_available():
-    return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+    return (
+        is_flash_attn_3_available()
+        or is_flash_attn_2_available()
+        or is_torch_npu_available()
+        or is_torch_xpu_available()
+    )


 # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
@@ -97,7 +103,7 @@ def _lazy_imports(implementation: Optional[str]):
     if flash_attn_varlen_func is None or flash_attn_func is None:
         raise ValueError(
             f"Could not find the currently requested flash attention implementation at `{implementation}`."
-            f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
+            f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn2`."
         )

     return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input

0 commit comments
