
Commit 2f9590c

support fope (#4043)

* fope
* update config format
* update fope params
* merge main

1 parent 863ba74 commit 2f9590c

File tree: 3 files changed (+200, -6 lines)

lmdeploy/pytorch/backends/default/rotary_embedding.py

Lines changed: 69 additions & 2 deletions
@@ -3,10 +3,11 @@
 import math
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 
-from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, RotaryEmbeddingBuilder,
-                                RotaryEmbeddingImpl, YarnParameters)
+from ..rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,
+                                RotaryEmbeddingBuilder, RotaryEmbeddingImpl, YarnParameters)
 
 
 def _rotary_embedding_fwd(position_ids: torch.Tensor,
@@ -270,6 +271,64 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
                                      device_type=device)
 
 
+class FopeRotaryEmbeddingImpl(RotaryEmbeddingImpl):
+
+    def __init__(self,
+                 dim: int,
+                 max_position_embeddings: int = 4096,
+                 scaling_factor: float = 1.0,
+                 params: FopeParameters = None):
+        super().__init__(dim, scaling_factor=scaling_factor)
+        self.head_dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.attention_scaling = scaling_factor
+        self.params = params
+
+        inv_freq = self.params.inv_freq
+        inv_freq_idx_selected = inv_freq > 2 * torch.pi / self.max_position_embeddings
+        if self.params.num_inv_freq is not None and inv_freq_idx_selected.sum() > (inv_freq.shape[-1] -
+                                                                                   self.params.num_inv_freq):
+            inv_freq_idx_selected[-self.params.num_inv_freq:] = False
+        self.inv_freq = inv_freq[inv_freq_idx_selected]
+        self.register_buffer('inv_freq', self.inv_freq, persistent=False)
+
+    def forward(self, x: torch.Tensor, position_ids: torch.Tensor, sin_coef: torch.Tensor, cos_coef: torch.Tensor):
+        """forward."""
+        if self.inv_freq.device != x.device:
+            self.inv_freq = self.inv_freq.to(x.device)
+
+        inv_freq = self.inv_freq
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+
+        batch_size, seq_len, _ = x.shape
+        if self.params.fope_sep_head:
+            pos_cos = freqs.cos().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1)
+            pos_sin = freqs.sin().unsqueeze(1).expand(batch_size, self.params.num_key_value_heads, seq_len, -1)
+        else:
+            pos_cos = freqs.cos()
+            pos_sin = freqs.sin()
+
+        if self.params.fope_sep_head:
+            sin = torch.einsum('bhtD, hDd -> bthd', pos_sin, sin_coef.float())
+            cos = torch.einsum('bhtD, hDd -> bthd', pos_cos, cos_coef.float())
+        else:
+            sin = torch.einsum('btD, Dd -> btd', pos_sin, sin_coef.float())
+            cos = torch.einsum('btD, Dd -> btd', pos_cos, cos_coef.float())
+
+        sin = F.pad(input=sin, pad=(0, self.head_dim // 2 - sin.size(-1)), mode='constant', value=1)
+        cos = F.pad(input=cos, pad=(0, self.head_dim // 2 - cos.size(-1)), mode='constant', value=1)
+
+        sin = torch.cat((sin, sin), dim=-1)
+        cos = torch.cat((cos, cos), dim=-1)
+
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 class DefaultRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
     """Rotary embedding builder."""
 
@@ -282,6 +341,7 @@ def build(
         yarn_params: YarnParameters = None,
         longrope_params: LongRoPEScalingParameters = None,
         llama3_params: Llama3Parameters = None,
+        fope_params: FopeParameters = None,
         emb_type: RopeType = RopeType.Default,
     ):
         """build."""
@@ -302,5 +362,12 @@ def build(
                 max_position_embeddings=max_position_embeddings,
                 longrope_params=longrope_params,
             )
+        elif emb_type == RopeType.Fope:
+            return FopeRotaryEmbeddingImpl(
+                dim,
+                max_position_embeddings=max_position_embeddings,
+                scaling_factor=scaling_factor,
+                params=fope_params,
+            )
         else:
             raise NotImplementedError(f'Unsupported embedding type: {emb_type}')
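Note: below is a minimal shape sketch of the fope_sep_head branch added above, with hypothetical sizes and random tensors standing in for the loaded sin/cos coefficients. It only illustrates how the einsum, padding, and concatenation produce a (batch, seq, heads, head_dim) cos/sin pair; it is not part of the commit.

import torch
import torch.nn.functional as F

batch, seq_len, num_kv_heads, head_dim = 2, 16, 4, 64
num_freq = 20  # hypothetical number of inv_freq entries kept after low-frequency filtering

freqs = torch.rand(batch, seq_len, num_freq)  # stands in for inv_freq expanded over position_ids
pos_cos = freqs.cos().unsqueeze(1).expand(batch, num_kv_heads, seq_len, -1)
pos_sin = freqs.sin().unsqueeze(1).expand(batch, num_kv_heads, seq_len, -1)

# per-head mixing coefficients of shape (heads, D, d); random here, loaded weights in the model
cos_coef = torch.rand(num_kv_heads, num_freq, num_freq)
sin_coef = torch.rand(num_kv_heads, num_freq, num_freq)

cos = torch.einsum('bhtD, hDd -> bthd', pos_cos, cos_coef)  # (batch, seq, heads, d)
sin = torch.einsum('bhtD, hDd -> bthd', pos_sin, sin_coef)

# pad the remaining rotary dims with the constant 1 (as the commit does), then duplicate the halves
cos = F.pad(cos, (0, head_dim // 2 - cos.size(-1)), value=1)
sin = F.pad(sin, (0, head_dim // 2 - sin.size(-1)), value=1)
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
assert cos.shape == (batch, seq_len, num_kv_heads, head_dim)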

lmdeploy/pytorch/backends/rotary_embedding.py

Lines changed: 14 additions & 1 deletion
@@ -4,6 +4,8 @@
 from enum import Enum, auto
 from typing import List
 
+import torch
+
 
 class RopeType(Enum):
     """Rotary embedding type."""
@@ -13,6 +15,7 @@ class RopeType(Enum):
     Llama3 = auto()
     Yarn = auto()
     LongRoPEScaling = auto()
+    Fope = auto()
 
 
 @dataclass
@@ -43,11 +46,20 @@ class Llama3Parameters:
     original_max_position_embeddings: int = 8192
 
 
+@dataclass
+class FopeParameters:
+    """Fope parameters."""
+    num_inv_freq: int = None
+    num_key_value_heads: int = 1
+    fope_sep_head: bool = False
+    inv_freq: torch.Tensor = None
+
+
 class RotaryEmbeddingImpl(ABC):
     """Rotary embedding implementation api."""
 
     @abstractmethod
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, **kwargs):
         """forward."""
         raise NotImplementedError
 
@@ -65,6 +77,7 @@ def build(
         yarn_params: YarnParameters = None,
         longrope_params: LongRoPEScalingParameters = None,
         llama3_params: Llama3Parameters = None,
+        fope_params: FopeParameters = None,
         emb_type: RopeType = RopeType.Default,
     ):
         """build."""

lmdeploy/pytorch/nn/rotary_embedding.py

Lines changed: 117 additions & 3 deletions
@@ -2,11 +2,13 @@
 
 import math
 
+import torch
 from torch import Tensor, nn
 from transformers import PretrainedConfig
 
 from ..backends import OpType, get_backend
-from ..backends.rotary_embedding import Llama3Parameters, LongRoPEScalingParameters, RopeType, YarnParameters
+from ..backends.rotary_embedding import (FopeParameters, Llama3Parameters, LongRoPEScalingParameters, RopeType,
+                                         YarnParameters)
 
 
 def _get_default_rope_parameters(config: PretrainedConfig):
@@ -92,6 +94,23 @@ def _get_llama3_parameters(config: PretrainedConfig):
     return dict(emb_type=RopeType.Llama3, scaling_factor=scaling_factor, llama3_params=params)
 
 
+def _get_fope_parameters(config: PretrainedConfig):
+    """Get fope parameters."""
+    # check if fope is used
+    rope_scaling = getattr(config, 'rope_scaling', dict())
+    fope_keys = ['fope_sep_head', 'fope_num_inv_freq']
+    is_fope = any(key in rope_scaling for key in fope_keys)
+    if not is_fope:
+        return dict()
+
+    params = FopeParameters()
+    rope_scaling = config.rope_scaling
+    params.num_inv_freq = rope_scaling.get('fope_num_inv_freq', rope_scaling.get('num_inv_freq', params.num_inv_freq))
+    params.num_key_value_heads = config.num_key_value_heads
+    params.fope_sep_head = rope_scaling['fope_sep_head']
+    return dict(fope_params=params)
+
+
 def build_rotary_params(config: PretrainedConfig):
     """Get scaling_factor rotary params, and emb_type."""
     params = dict(emb_type=RopeType.Default)
@@ -100,6 +119,8 @@ def build_rotary_params(config: PretrainedConfig):
     if rope_scaling is not None:
         # BC: "rope_type" was originally "type"
         rope_type_str = config.rope_scaling.get('rope_type', config.rope_scaling.get('type', 'default'))
+        if rope_type_str == 'fope':
+            rope_type_str = 'default'
         build_funcs = dict(default=_get_default_rope_parameters,
                            linear=_get_linear_scaling_rope_parameters,
                            dynamic=_get_dynamic_ntk_parameters,
@@ -108,6 +129,7 @@ def build_rotary_params(config: PretrainedConfig):
                            su=_get_longrope_parameters,
                            llama3=_get_llama3_parameters)
         params.update(build_funcs[rope_type_str](config))
+    params.update(_get_fope_parameters(config))
 
     # update partial_rotary_factor
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else None
@@ -124,6 +146,7 @@ def build_rotary_embedding(dim: int,
                           yarn_params: YarnParameters = None,
                           longrope_params: LongRoPEScalingParameters = None,
                           llama3_params: Llama3Parameters = None,
+                           fope_params: FopeParameters = None,
                           emb_type: RopeType = RopeType.Default,
                           partial_rotary_factor: float = None) -> nn.Module:
     """Build rotary embedding op."""
@@ -134,7 +157,7 @@ def build_rotary_embedding(dim: int,
     # update rope_dim
     if partial_rotary_factor is not None:
         dim = int(dim * partial_rotary_factor)
-    return builder.build(dim,
+    impl = builder.build(dim,
                         max_position_embeddings,
                         base,
                         scaling_factor,
@@ -143,6 +166,14 @@ def build_rotary_embedding(dim: int,
                         llama3_params=llama3_params,
                         emb_type=emb_type)
 
+    if fope_params is not None:
+        inv_freq = impl.inv_freq
+        fope_params.inv_freq = inv_freq
+        fope = FopeRotaryEmbedding(dim, max_position_embeddings, scaling_factor, fope_params)
+        return fope
+
+    return impl
+
 
 def build_rotary_embedding_from_config(config: PretrainedConfig) -> nn.Module:
     """Build rotary embedding op from config."""
@@ -169,4 +200,87 @@ def __init__(self):
 
     def forward(self, query: Tensor, key: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True):
         """forward."""
-        return self.impl.forward(query, key, cos, sin, inplace)
+
+        assert query.dim() == key.dim() == 3, 'Expected query key (seq_len, heads, head_dim)'
+        assert cos.dim() <= 3 and sin.dim() <= 3
+
+        need_reshape = False
+        if cos.dim() == 3:
+            # for fope
+            need_reshape = True
+            query_shape = query.shape
+            key_shape = key.shape
+            cos = cos.flatten(0, 1)
+            sin = sin.flatten(0, 1)
+            seq_len = cos.size(0)
+            query = query.view(seq_len, -1, query.size(-1))
+            key = key.view(seq_len, -1, key.size(-1))
+
+        query, key = self.impl.forward(query, key, cos, sin, inplace)
+
+        if need_reshape:
+            query = query.view(query_shape)
+            key = key.view(key_shape)
+        return query, key
+
+
+class FopeRotaryEmbedding(nn.Module):
+    """Fope rotary embedding."""
+
+    def __init__(self, dim: int, max_position_embeddings: int, attention_scaling: float, params: FopeParameters):
+        super().__init__()
+
+        num_key_value_heads, tp = self.update_num_kv_heads(params.num_key_value_heads)
+        self.tp = tp
+        params.num_key_value_heads = num_key_value_heads
+
+        # build impl
+        backend = get_backend()
+        builder = backend.get_layer_impl_builder(OpType.RotaryEmbedding)
+        self.impl = builder.build(dim,
+                                  max_position_embeddings=max_position_embeddings,
+                                  scaling_factor=attention_scaling,
+                                  fope_params=params,
+                                  emb_type=RopeType.Fope)
+
+        # setup params
+        inv_freq = self.impl.inv_freq
+        self.input_dim = inv_freq.shape[-1]
+        self.output_dim = inv_freq.shape[-1]
+        self.cos_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim),
+                                     requires_grad=False)
+        self.sin_coef = nn.Parameter(torch.empty(num_key_value_heads, self.input_dim, self.output_dim),
+                                     requires_grad=False)
+        if self.tp:
+            self.cos_coef.weight_loader = self.weight_loader
+            self.sin_coef.weight_loader = self.weight_loader
+
+    @staticmethod
+    def update_num_kv_heads(num_key_value_heads: int):
+        """Update num_key_value_heads."""
+        from lmdeploy.pytorch.distributed import get_dist_manager
+        dist_mgr = get_dist_manager()
+        dist_ctx = dist_mgr.current_context()
+        tp = dist_ctx.dist_config.attn_tp
+        # tp = dist_ctx.dist_config.attn_config.tp
+        if tp > 1:
+            num_key_value_heads = max(1, num_key_value_heads // tp)
+        return num_key_value_heads, tp
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """Weight loader."""
+        from lmdeploy.pytorch.distributed import get_tp_world_rank
+        world_size, rank = get_tp_world_rank()
+        num_key_value_heads = loaded_weight.size(0)
+
+        if num_key_value_heads < world_size:
+            n_replicate = world_size // num_key_value_heads
+            world_size = num_key_value_heads
+            rank = rank // n_replicate
+
+        loaded_weight = loaded_weight.chunk(world_size, dim=0)[rank]
+        param.copy_(loaded_weight)
+
+    def forward(self, x: Tensor, position_ids: Tensor):
+        """forward."""
+        return self.impl.forward(x, position_ids, sin_coef=self.sin_coef, cos_coef=self.cos_coef)
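For context, a sketch of a rope_scaling entry that would take the new path (key names follow _get_fope_parameters above; the concrete values are invented): build_rotary_params maps 'fope' back to the default rope type and attaches fope_params, which later makes build_rotary_embedding wrap the base impl in FopeRotaryEmbedding.

from transformers import PretrainedConfig

from lmdeploy.pytorch.nn.rotary_embedding import build_rotary_params

config = PretrainedConfig(
    num_key_value_heads=8,
    rope_scaling={
        'rope_type': 'fope',      # rewritten to 'default' before the build_funcs lookup
        'fope_sep_head': True,    # presence of a fope_* key triggers _get_fope_parameters
        'fope_num_inv_freq': 8,   # illustrative value
    },
)
rotary_params = build_rotary_params(config)
# rotary_params now carries emb_type=RopeType.Default plus fope_params=FopeParameters(...)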
