@@ -91,11 +91,11 @@ def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attent
9191 self .dtypes .add (dtype )
9292 super ().__init__ (device = device , layer = layer , layer_idx = layer_idx , textmodel_json_config = {}, dtype = dtype , special_tokens = {"start" : 2 , "pad" : 0 }, layer_norm_hidden_state = False , model_class = comfy .text_encoders .llama .Gemma3_12B , enable_attention_masks = attention_mask , return_attention_masks = attention_mask , model_options = model_options )
9393
94- def generate (self , tokens , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed ):
94+ def generate (self , tokens , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed , presence_penalty ):
9595 tokens_only = [[t [0 ] for t in b ] for b in tokens ]
9696 embeds , _ , _ , embeds_info = self .process_tokens (tokens_only , self .execution_device )
9797 comfy .utils .normalize_image_embeddings (embeds , embeds_info , self .transformer .model .config .hidden_size ** 0.5 )
98- return self .transformer .generate (embeds , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed , stop_tokens = [106 ]) # 106 is <end_of_turn>
98+ return self .transformer .generate (embeds , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed , stop_tokens = [106 ], presence_penalty = presence_penalty ) # 106 is <end_of_turn>
9999
100100class DualLinearProjection (torch .nn .Module ):
101101 def __init__ (self , in_dim , out_dim_video , out_dim_audio , dtype = None , device = None , operations = None ):
@@ -189,8 +189,8 @@ def encode_token_weights(self, token_weight_pairs):
189189
190190 return out .to (device = out_device , dtype = torch .float ), pooled , extra
191191
192- def generate (self , tokens , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed ):
193- return self .gemma3_12b .generate (tokens ["gemma3_12b" ], do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed )
192+ def generate (self , tokens , do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed , presence_penalty ):
193+ return self .gemma3_12b .generate (tokens ["gemma3_12b" ], do_sample , max_length , temperature , top_k , top_p , min_p , repetition_penalty , seed , presence_penalty )
194194
195195 def load_sd (self , sd ):
196196 if "model.layers.47.self_attn.q_norm.weight" in sd :
0 commit comments