 from transformers import PreTrainedTokenizer
 
 from swift.llm import HfConfigFactory, get_llm_model
-from ...utils import get_device, get_dist_setting
+from swift.utils import get_cu_seqlens_from_position_ids, get_device, get_dist_setting
 from .utils import GatherLoss
 
 
@@ -252,7 +252,7 @@ def _attention(query, key, value, *args, **kwargs):
         if self.rp_world_size is not None and self.rp_world_size > 1:
             from .zigzag_ring_attn import zigzag_ring_flash_attn_varlen_func
             position_ids = kwargs['position_ids']
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             position_ids = self._split_packed(position_ids, cu_seqlens)
             mask = position_ids != -1
@@ -430,7 +430,7 @@ def _do_pad(tensor):
            return tensor
 
        if position_ids is not None and self.rp_world_size > 1:
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
            all_tensors = []
            for i in range(len(cu_seqlens) - 1):
                if dim == 1:
@@ -468,7 +468,7 @@ def gather(self, local_output, dim: int, position_ids=None):
            gathered_rp = [torch.zeros_like(rp_chunk) for _ in range(self.rp_world_size)]
            torch.distributed.all_gather(gathered_rp, rp_chunk, group=self.rp_group)
 
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
            all_tensor_length = []
            for i in range(len(cu_seqlens) - 1):
                length = cu_seqlens[i + 1] - cu_seqlens[i]
@@ -501,17 +501,6 @@ def gather(self, local_output, dim: int, position_ids=None):
            gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim)
            return gathered_sp.contiguous()
 
-    @staticmethod
-    def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
-        position_ids = position_ids[0]
-        seq_start_indices = torch.where(position_ids == 0)[0]
-        seq_end_indices = torch.cat(
-            [seq_start_indices[1:],
-             torch.tensor([len(position_ids)], device=position_ids.device)])
-        seq_lengths = seq_end_indices - seq_start_indices
-        cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0)
-        return cu_seqlens
-
     def _split_packed(self, value, cu_seqlens, dim=1):
         """Split and re-group in zigzag"""
         local_values = []
@@ -538,7 +527,7 @@ def split(self, input, dim: int, position_ids=None):
        if self.rp_world_size > 1:
            input_dim = input.dim()
            assert input_dim >= 2
-            cu_seqlens = self.get_cu_seqlens_from_position_ids(position_ids)
+            cu_seqlens = get_cu_seqlens_from_position_ids(position_ids)
            assert torch.all(cu_seqlens % (2 * self.rp_world_size) == 0)
            value_chunks = self._split_packed(input, cu_seqlens, dim=dim)
            local_value = value_chunks.chunk(self.sp_world_size, dim=dim)[self.sp_rank].contiguous()
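For context: the helper removed from this class (and now imported from swift.utils) converts packed position_ids into cumulative sequence lengths for varlen attention. Below is a minimal sketch of that behavior, mirroring the removed static method; the actual swift.utils implementation may differ in details.

```python
# Sketch only: mirrors the static method removed in this diff, not the
# canonical swift.utils implementation.
import torch


def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
    # position_ids has shape (1, total_len); each packed sequence restarts at 0.
    position_ids = position_ids[0]
    seq_start_indices = torch.where(position_ids == 0)[0]
    seq_end_indices = torch.cat(
        [seq_start_indices[1:],
         torch.tensor([len(position_ids)], device=position_ids.device)])
    seq_lengths = seq_end_indices - seq_start_indices
    # Prepend 0 and accumulate to get cu_seqlens boundaries.
    return torch.cumsum(
        torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0)


# Two packed sequences of lengths 3 and 4 -> tensor([0, 3, 7])
packed = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
print(get_cu_seqlens_from_position_ids(packed))
```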