22
33import math
44
5+ import torch
56from torch import Tensor , nn
67from transformers import PretrainedConfig
78
89from ..backends import OpType , get_backend
9- from ..backends .rotary_embedding import Llama3Parameters , LongRoPEScalingParameters , RopeType , YarnParameters
10+ from ..backends .rotary_embedding import (FopeParameters , Llama3Parameters , LongRoPEScalingParameters , RopeType ,
11+ YarnParameters )
1012
1113
1214def _get_default_rope_parameters (config : PretrainedConfig ):
@@ -92,6 +94,23 @@ def _get_llama3_parameters(config: PretrainedConfig):
9294 return dict (emb_type = RopeType .Llama3 , scaling_factor = scaling_factor , llama3_params = params )
9395
9496
def _get_fope_parameters(config: PretrainedConfig):
    """Build FoPE parameters from ``config.rope_scaling``.

    Returns an empty dict when no FoPE-related key is present, otherwise
    ``dict(fope_params=...)`` to be merged into the rotary params.
    """
    # rope_scaling may be absent OR explicitly None on the config
    # (build_rotary_params checks `rope_scaling is not None`), so normalize
    # both cases to an empty dict before key lookups.
    rope_scaling = getattr(config, 'rope_scaling', None) or dict()
    fope_keys = ['fope_sep_head', 'fope_num_inv_freq']
    is_fope = any(key in rope_scaling for key in fope_keys)
    if not is_fope:
        return dict()

    params = FopeParameters()
    params.num_inv_freq = rope_scaling.get('fope_num_inv_freq', rope_scaling.get('num_inv_freq', params.num_inv_freq))
    params.num_key_value_heads = config.num_key_value_heads
    # `is_fope` is true when ANY fope key is present; fall back to the
    # FopeParameters default instead of raising KeyError when only
    # 'fope_num_inv_freq' triggered the FoPE path.
    params.fope_sep_head = rope_scaling.get('fope_sep_head', params.fope_sep_head)
    return dict(fope_params=params)
113+
95114def build_rotary_params (config : PretrainedConfig ):
96115 """Get scaling_factor rotary params, and emb_type."""
97116 params = dict (emb_type = RopeType .Default )
@@ -100,6 +119,8 @@ def build_rotary_params(config: PretrainedConfig):
100119 if rope_scaling is not None :
101120 # BC: "rope_type" was originally "type"
102121 rope_type_str = config .rope_scaling .get ('rope_type' , config .rope_scaling .get ('type' , 'default' ))
122+ if rope_type_str == 'fope' :
123+ rope_type_str = 'default'
103124 build_funcs = dict (default = _get_default_rope_parameters ,
104125 linear = _get_linear_scaling_rope_parameters ,
105126 dynamic = _get_dynamic_ntk_parameters ,
@@ -108,6 +129,7 @@ def build_rotary_params(config: PretrainedConfig):
108129 su = _get_longrope_parameters ,
109130 llama3 = _get_llama3_parameters )
110131 params .update (build_funcs [rope_type_str ](config ))
132+ params .update (_get_fope_parameters (config ))
111133
112134 # update partial_rotary_factor
113135 partial_rotary_factor = config .partial_rotary_factor if hasattr (config , 'partial_rotary_factor' ) else None
@@ -124,6 +146,7 @@ def build_rotary_embedding(dim: int,
124146 yarn_params : YarnParameters = None ,
125147 longrope_params : LongRoPEScalingParameters = None ,
126148 llama3_params : Llama3Parameters = None ,
149+ fope_params : FopeParameters = None ,
127150 emb_type : RopeType = RopeType .Default ,
128151 partial_rotary_factor : float = None ) -> nn .Module :
129152 """Build rotary embedding op."""
@@ -134,7 +157,7 @@ def build_rotary_embedding(dim: int,
134157 # update rope_dim
135158 if partial_rotary_factor is not None :
136159 dim = int (dim * partial_rotary_factor )
137- return builder .build (dim ,
160+ impl = builder .build (dim ,
138161 max_position_embeddings ,
139162 base ,
140163 scaling_factor ,
@@ -143,6 +166,14 @@ def build_rotary_embedding(dim: int,
143166 llama3_params = llama3_params ,
144167 emb_type = emb_type )
145168
169+ if fope_params is not None :
170+ inv_freq = impl .inv_freq
171+ fope_params .inv_freq = inv_freq
172+ fope = FopeRotaryEmbedding (dim , max_position_embeddings , scaling_factor , fope_params )
173+ return fope
174+
175+ return impl
176+
146177
147178def build_rotary_embedding_from_config (config : PretrainedConfig ) -> nn .Module :
148179 """Build rotary embedding op from config."""
@@ -169,4 +200,87 @@ def __init__(self):
169200
    def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
        """Apply rotary position embedding to ``query`` and ``key``.

        Args:
            query: ``(seq_len, num_heads, head_dim)`` tensor.
            key: ``(seq_len, num_kv_heads, head_dim)`` tensor.
            cos: cosine table; 2-dim for plain RoPE. A 3-dim table is the
                FoPE case — presumably per-head tables with the first two
                dims covering (seq, heads); TODO confirm against the FoPE
                impl that produces it.
            sin: sine table, same layout as ``cos``.
            inplace: forwarded to the backend impl; when True the rotation
                may be applied in place.

        Returns:
            The rotated ``(query, key)`` pair, reshaped back to the input
            shapes when the FoPE flattening was applied.
        """

        assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)'
        assert cos.dim() <= 3 and sin.dim() <= 3

        need_reshape = False
        if cos.dim() == 3:
            # for fope: the backend impl only understands 2-dim cos/sin, so
            # fold the leading two dims of the tables together and view
            # query/key with a matching leading length; undo below.
            need_reshape = True
            query_shape = query.shape
            key_shape = key.shape
            cos = cos.flatten(0, 1)
            sin = sin.flatten(0, 1)
            seq_len = cos.size(0)
            # NOTE(review): assumes query/key numel is divisible so that the
            # view lines up with the flattened tables — confirm with callers.
            query = query.view(seq_len, -1, query.size(-1))
            key = key.view(seq_len, -1, key.size(-1))

        query, key = self.impl.forward(query, key, cos, sin, inplace)

        if need_reshape:
            # restore the caller-visible (seq_len, heads, head_dim) shapes
            query = query.view(query_shape)
            key = key.view(key_shape)
        return query, key
226+
class FopeRotaryEmbedding(nn.Module):
    """FoPE rotary embedding.

    Wraps the backend rotary-embedding impl built with ``RopeType.Fope`` and
    owns the per-kv-head cos/sin coefficient tables, which are filled from
    checkpoint weights (sharded under attention tensor parallelism).
    """

    def __init__(self, dim: int, max_position_embeddings: int, attention_scaling: float, params: FopeParameters):
        super().__init__()

        kv_heads, tp = self.update_num_kv_heads(params.num_key_value_heads)
        self.tp = tp
        params.num_key_value_heads = kv_heads

        # build the backend implementation for RopeType.Fope
        impl_builder = get_backend().get_layer_impl_builder(OpType.RotaryEmbedding)
        self.impl = impl_builder.build(dim,
                                       max_position_embeddings=max_position_embeddings,
                                       scaling_factor=attention_scaling,
                                       fope_params=params,
                                       emb_type=RopeType.Fope)

        # coefficient tables sized from the impl's inverse frequencies
        freq_dim = self.impl.inv_freq.shape[-1]
        self.input_dim = freq_dim
        self.output_dim = freq_dim
        coef_shape = (kv_heads, self.input_dim, self.output_dim)
        self.cos_coef = nn.Parameter(torch.empty(coef_shape), requires_grad=False)
        self.sin_coef = nn.Parameter(torch.empty(coef_shape), requires_grad=False)
        if self.tp:
            # custom loader shards the checkpoint tables across tp ranks
            self.cos_coef.weight_loader = self.weight_loader
            self.sin_coef.weight_loader = self.weight_loader

    @staticmethod
    def update_num_kv_heads(num_key_value_heads: int):
        """Return (per-rank kv-head count, attn tp size)."""
        from lmdeploy.pytorch.distributed import get_dist_manager
        dist_ctx = get_dist_manager().current_context()
        tp = dist_ctx.dist_config.attn_tp
        if tp > 1:
            # keep at least one head per rank
            num_key_value_heads = max(1, num_key_value_heads // tp)
        return num_key_value_heads, tp

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's shard of ``loaded_weight`` into ``param``."""
        from lmdeploy.pytorch.distributed import get_tp_world_rank
        world_size, rank = get_tp_world_rank()
        total_heads = loaded_weight.size(0)

        # fewer kv heads than ranks: groups of ranks replicate one head
        if total_heads < world_size:
            n_replicate = world_size // total_heads
            world_size = total_heads
            rank = rank // n_replicate

        shard = loaded_weight.chunk(world_size, dim=0)[rank]
        param.copy_(shard)

    def forward(self, x: Tensor, position_ids: Tensor):
        """forward."""
        return self.impl.forward(x, position_ids, sin_coef=self.sin_coef, cos_coef=self.cos_coef)
0 commit comments