 
 from lightllm.common.mem_utils import MemoryManager
 
+
 class Gemma_2bTpPartModel(TpPartBaseModel):
     # weight class
     pre_and_post_weight_class = Gemma_2bPreAndPostLayerWeight
@@ -38,20 +39,22 @@ def _verify_params(self):
         # assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
-
+
     def _init_custom(self):
         self._init_to_get_rotary()
         return
-
+
     def _init_mem_manager(self):
-        self.mem_manager = MemoryManager(self.max_total_token_num,
-                                         dtype=self.data_type,
-                                         head_num=self.config["num_key_value_heads"],  # [SYM] always == 1
-                                         head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
-                                         layer_num=self.config["num_hidden_layers"])
+        self.mem_manager = MemoryManager(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=self.config["num_key_value_heads"],  # [SYM] always == 1
+            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+            layer_num=self.config["num_hidden_layers"],
+            mem_fraction=self.mem_fraction,
+        )
         return
 
-
     def _init_to_get_rotary(self, default_base=10000):
         if self.config.get("rope_scaling", {}) is None:
             rope_scaling_factor = 1.0
@@ -64,16 +67,16 @@ def _init_to_get_rotary(self, default_base=10000):
             max_seq_len = self.config["max_sequence_length"]
         else:
             max_position_embeddings = self.config.get(
-                "max_position_embeddings",
-                2048 if base <= 10000.0 + 1e-5 else 16384
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
             )
             max_seq_len = max_position_embeddings * rope_scaling_factor
 
-        inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_))
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
+        )
         t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
         freqs = torch.outer(t, inv_freq)
 
         self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return
-