2 changes: 1 addition & 1 deletion swift/trainers/optimizers/galore/utils.py
@@ -185,7 +185,7 @@ def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, A
         from .adafactor import GaLoreAdafactor
         optimizer_cls = GaLoreAdafactor
         optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
-    elif args.optim in ('adamw_hf', 'adamw_torch'):
+    elif args.optim in ('adamw_hf', 'adamw_torch', 'adamw_torch_fused'):
         if config.quantize:
             assert importlib.util.find_spec('q_galore_torch') is not None, \
                 'Please install q-galore by `pip install q_galore_torch`'
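
The added `adamw_torch_fused` entry routes the fused torch AdamW setting through the same GaLore AdamW branch as `adamw_hf` and `adamw_torch`. A minimal sketch of what that enables (hedged: it uses the stock `transformers.TrainingArguments`; the `GaLoreConfig` construction is not shown in this diff):

```python
from transformers import TrainingArguments

# Hedged sketch: with this change, a fused-AdamW training config is handled by
# the GaLore AdamW branch above instead of being an unrecognised optim name here.
args = TrainingArguments(output_dir='output', optim='adamw_torch_fused')

# get_optimizer(args, config) from this file now treats these args like
# 'adamw_hf'/'adamw_torch'; with config.quantize=True it additionally requires
# `pip install q_galore_torch`, as the assert above enforces.
```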
21 changes: 5 additions & 16 deletions swift/trainers/sequence_parallel/ulysses.py
@@ -9,7 +9,7 @@
 from transformers import PreTrainedTokenizer

 from swift.llm import HfConfigFactory, get_llm_model
-from ...utils import get_device, get_dist_setting
+from swift.utils import get_cu_seqlens_from_position_ids, get_device, get_dist_setting
 from .utils import GatherLoss


@@ -252,7 +252,7 @@ def _attention(query, key, value, *args, **kwargs):
         if self.rp_world_size is not None and self.rp_world_size > 1:
             from .zigzag_ring_attn import zigzag_ring_flash_attn_varlen_func
             position_ids = kwargs['position_ids']
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             position_ids = self._split_packed(position_ids, cu_seqlens)
             mask = position_ids != -1
@@ -430,7 +430,7 @@ def _do_pad(tensor):
             return tensor

         if position_ids is not None and self.rp_world_size > 1:
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
             all_tensors = []
             for i in range(len(cu_seqlens) - 1):
                 if dim == 1:
@@ -468,7 +468,7 @@ def gather(self, local_output, dim: int, position_ids=None):
            gathered_rp = [torch.zeros_like(rp_chunk) for _ in range(self.rp_world_size)]
            torch.distributed.all_gather(gathered_rp, rp_chunk, group=self.rp_group)

-           cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+           cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
            all_tensor_length = []
            for i in range(len(cu_seqlens) - 1):
                length = cu_seqlens[i + 1] - cu_seqlens[i]
@@ -501,17 +501,6 @@ def gather(self, local_output, dim: int, position_ids=None):
            gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim)
         return gathered_sp.contiguous()

-    @staticmethod
-    def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
-        position_ids = position_ids[0]
-        seq_start_indices = torch.where(position_ids == 0)[0]
-        seq_end_indices = torch.cat(
-            [seq_start_indices[1:],
-             torch.tensor([len(position_ids)], device=position_ids.device)])
-        seq_lengths = seq_end_indices - seq_start_indices
-        cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0)
-        return cu_seqlens
-
     def _split_packed(self, value, cu_seqlens, dim=1):
         """Split and re-group in zigzag"""
         local_values = []
@@ -538,7 +527,7 @@ def split(self, input, dim: int, position_ids=None):
         if self.rp_world_size > 1:
             input_dim = input.dim()
             assert input_dim >= 2
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
             assert torch.all(cu_seqlens % (2 * self.rp_world_size) == 0)
             value_chunks = self._split_packed(input, cu_seqlens, dim=dim)
             local_value = value_chunks.chunk(self.sp_world_size, dim=dim)[self.sp_rank].contiguous()
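
As an aside on how the relocated helper feeds the ring-attention path in this file, here is a hedged, self-contained sketch (the packed batch and `rp_world_size = 2` are made-up illustrative values; the import relies on the `swift.utils` export added by this PR):

```python
import torch

from swift.utils import get_cu_seqlens_from_position_ids

# Made-up packed batch: three sequences of lengths 4, 8 and 4 concatenated along
# the sequence dimension; position ids restart at 0 at each sequence boundary.
position_ids = torch.tensor([[0, 1, 2, 3,
                              0, 1, 2, 3, 4, 5, 6, 7,
                              0, 1, 2, 3]])

cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
print(cu_seqlens.tolist())  # [0, 4, 12, 16]

# max_seqlen as computed in _attention above: length of the longest packed sequence.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(max_seqlen)  # 8

# The assertion in split() above requires every boundary to be divisible by
# 2 * rp_world_size, so each packed sequence can be cut into 2 * rp_world_size
# equal chunks for zigzag ring attention.
rp_world_size = 2
assert torch.all(cu_seqlens % (2 * rp_world_size) == 0)
```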
7 changes: 4 additions & 3 deletions swift/utils/__init__.py
@@ -11,9 +11,10 @@
 from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing
 from .torch_utils import (Serializer, activate_parameters, check_shared_disk, disable_safe_ddp_context_use_barrier,
                           empty_cache, find_all_linears, find_embedding, find_layers, find_norm, freeze_parameters,
-                          gc_collect, get_current_device, get_device, get_device_count, get_model_parameter_info,
-                          get_n_params_grads, init_process_group, safe_ddp_context, seed_worker, set_default_ddp_config,
-                          set_device, show_layers, time_synchronize, unwrap_model_for_generation)
+                          gc_collect, get_cu_seqlens_from_position_ids, get_current_device, get_device,
+                          get_device_count, get_model_parameter_info, get_n_params_grads, init_process_group,
+                          safe_ddp_context, seed_worker, set_default_ddp_config, set_device, show_layers,
+                          time_synchronize, unwrap_model_for_generation)
 from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port,
                     format_time, get_env_args, import_external_file, json_parse_to_dict, lower_bound, parse_args,
                     patch_getattr, read_multi_line, remove_response, seed_everything, split_list, subprocess_run,
9 changes: 9 additions & 0 deletions swift/utils/torch_utils.py
@@ -364,6 +364,15 @@ def gc_collect() -> None:
     empty_cache()


+def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
+    position_ids = position_ids[0]
+    seq_start_indices = torch.where(position_ids == 0)[0]
+    seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(position_ids)], device=position_ids.device)])
+    seq_lengths = seq_end_indices - seq_start_indices
+    cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0)
+    return cu_seqlens
+
+
 class Serializer:

     @staticmethod
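
A quick worked example of the relocated helper, now a module-level function (a minimal sketch with made-up values):

```python
import torch

from swift.utils.torch_utils import get_cu_seqlens_from_position_ids

# A packed batch of three sequences with lengths 3, 2 and 4: position ids
# restart from 0 at every sequence boundary.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])

cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
print(cu_seqlens.tolist())  # [0, 3, 5, 9] -- cumulative sequence boundaries

# Recover the individual sequence lengths from the cumulative boundaries.
print((cu_seqlens[1:] - cu_seqlens[:-1]).tolist())  # [3, 2, 4]
```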