Skip to content

Commit b0fd65e

Browse files
authored
fix: regression in text generate with LTXAV model (Comfy-Org#13170)
1 parent 2a1f402 commit b0fd65e

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

comfy/text_encoders/lt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attent
9191
self.dtypes.add(dtype)
9292
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
9393

94-
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
94+
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
9595
tokens_only = [[t[0] for t in b] for b in tokens]
9696
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
9797
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
98-
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
98+
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
9999

100100
class DualLinearProjection(torch.nn.Module):
101101
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
@@ -189,8 +189,8 @@ def encode_token_weights(self, token_weight_pairs):
189189

190190
return out.to(device=out_device, dtype=torch.float), pooled, extra
191191

192-
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
193-
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
192+
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
193+
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
194194

195195
def load_sd(self, sd):
196196
if "model.layers.47.self_attn.q_norm.weight" in sd:

0 commit comments

Comments
 (0)