Commit b7a964a: fix galore (#5957)
1 parent e091c7f

4 files changed: +19, -20 lines

swift/trainers/optimizers/galore/utils.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, A
         from .adafactor import GaLoreAdafactor
         optimizer_cls = GaLoreAdafactor
         optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
-    elif args.optim in ('adamw_hf', 'adamw_torch'):
+    elif args.optim in ('adamw_hf', 'adamw_torch', 'adamw_torch_fused'):
         if config.quantize:
             assert importlib.util.find_spec('q_galore_torch') is not None, \
                 'Please install q-galore by `pip install q_galore_torch`'
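The hunk above only widens the optimizer dispatch: with this commit, a GaLore run that sets optim='adamw_torch_fused' takes the same branch as 'adamw_torch' instead of falling through. Below is a hedged, illustrative sketch of that dispatch using only the names visible in the diff (args.optim, config.quantize); it is not swift's full get_optimizer.

# Illustrative only: mirrors the branch changed in this hunk.
GALORE_ADAMW_OPTIMS = ('adamw_hf', 'adamw_torch', 'adamw_torch_fused')  # now includes the fused variant

def uses_galore_adamw_branch(optim: str, quantize: bool) -> bool:
    """Return True if `optim` takes the (Q-)GaLore AdamW branch shown in the diff."""
    if optim not in GALORE_ADAMW_OPTIMS:
        return False
    if quantize:
        import importlib.util
        # Quantized GaLore needs the optional q_galore_torch package, as the assert in the diff enforces.
        if importlib.util.find_spec('q_galore_torch') is None:
            raise ImportError('Please install q-galore by `pip install q_galore_torch`')
    return True

assert uses_galore_adamw_branch('adamw_torch_fused', quantize=False)  # previously this fell through
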

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 5 additions & 16 deletions
@@ -9,7 +9,7 @@
 from transformers import PreTrainedTokenizer
 
 from swift.llm import HfConfigFactory, get_llm_model
-from ...utils import get_device, get_dist_setting
+from swift.utils import get_cu_seqlens_from_position_ids, get_device, get_dist_setting
 from .utils import GatherLoss
 
 
@@ -252,7 +252,7 @@ def _attention(query, key, value, *args, **kwargs):
         if self.rp_world_size is not None and self.rp_world_size > 1:
             from .zigzag_ring_attn import zigzag_ring_flash_attn_varlen_func
             position_ids = kwargs['position_ids']
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             position_ids = self._split_packed(position_ids, cu_seqlens)
             mask = position_ids != -1
@@ -430,7 +430,7 @@ def _do_pad(tensor):
             return tensor
 
         if position_ids is not None and self.rp_world_size > 1:
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
             all_tensors = []
             for i in range(len(cu_seqlens) - 1):
                 if dim == 1:
@@ -468,7 +468,7 @@ def gather(self, local_output, dim: int, position_ids=None):
             gathered_rp = [torch.zeros_like(rp_chunk) for _ in range(self.rp_world_size)]
             torch.distributed.all_gather(gathered_rp, rp_chunk, group=self.rp_group)
 
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
             all_tensor_length = []
             for i in range(len(cu_seqlens) - 1):
                 length = cu_seqlens[i + 1] - cu_seqlens[i]
@@ -501,17 +501,6 @@ def gather(self, local_output, dim: int, position_ids=None):
             gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim)
             return gathered_sp.contiguous()
 
-    @staticmethod
-    def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
-        position_ids = position_ids[0]
-        seq_start_indices = torch.where(position_ids == 0)[0]
-        seq_end_indices = torch.cat(
-            [seq_start_indices[1:],
-             torch.tensor([len(position_ids)], device=position_ids.device)])
-        seq_lengths = seq_end_indices - seq_start_indices
-        cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0)
-        return cu_seqlens
-
     def _split_packed(self, value, cu_seqlens, dim=1):
         """Split and re-group in zigzag"""
         local_values = []
@@ -538,7 +527,7 @@ def split(self, input, dim: int, position_ids=None):
         if self.rp_world_size > 1:
             input_dim = input.dim()
             assert input_dim >= 2
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
             assert torch.all(cu_seqlens % (2 * self.rp_world_size) == 0)
             value_chunks = self._split_packed(input, cu_seqlens, dim=dim)
             local_value = value_chunks.chunk(self.sp_world_size, dim=dim)[self.sp_rank].contiguous()

swift/utils/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -11,9 +11,10 @@
 from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing
 from .torch_utils import (Serializer, activate_parameters, check_shared_disk, disable_safe_ddp_context_use_barrier,
                           empty_cache, find_all_linears, find_embedding, find_layers, find_norm, freeze_parameters,
-                          gc_collect, get_current_device, get_device, get_device_count, get_model_parameter_info,
-                          get_n_params_grads, init_process_group, safe_ddp_context, seed_worker, set_default_ddp_config,
-                          set_device, show_layers, time_synchronize, unwrap_model_for_generation)
+                          gc_collect, get_cu_seqlens_from_position_ids, get_current_device, get_device,
+                          get_device_count, get_model_parameter_info, get_n_params_grads, init_process_group,
+                          safe_ddp_context, seed_worker, set_default_ddp_config, set_device, show_layers,
+                          time_synchronize, unwrap_model_for_generation)
 from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port,
                     format_time, get_env_args, import_external_file, json_parse_to_dict, lower_bound, parse_args,
                     patch_getattr, read_multi_line, remove_response, seed_everything, split_list, subprocess_run,

swift/utils/torch_utils.py

Lines changed: 9 additions & 0 deletions
@@ -364,6 +364,15 @@ def gc_collect() -> None:
     empty_cache()
 
 
+def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
+    position_ids = position_ids[0]
+    seq_start_indices = torch.where(position_ids == 0)[0]
+    seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(position_ids)], device=position_ids.device)])
+    seq_lengths = seq_end_indices - seq_start_indices
+    cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0)
+    return cu_seqlens
+
+
 class Serializer:
 
     @staticmethod
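For context, a minimal, hedged usage sketch (not part of the commit) of the relocated helper: for a packed batch whose position_ids restart at 0 at every sequence boundary, it returns the cumulative sequence offsets that varlen attention kernels expect. The swift.utils export added in this commit is assumed.

# Minimal sketch, assuming the swift.utils export added above.
import torch
from swift.utils import get_cu_seqlens_from_position_ids

# Three packed sequences of lengths 3, 2 and 4 in a single row; each restarts at position 0.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
print(cu_seqlens)  # tensor([0, 3, 5, 9])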
