Description
Hi, thanks for your amazing work.

I'm exploring the TangoFlux codebase, and I have a question about how the unconditional text embedding is handled during training and inference for Classifier-Free Guidance (CFG).

In the forward method (presumably for training, when sft=True and self.uncondition=True), it appears that unconditional text input is simulated by setting the encoder_hidden_states of some samples to zero:
# In forward method (training)
if self.uncondition:
    mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
    if len(mask_indices) > 0:
        encoder_hidden_states[mask_indices] = 0  # Setting embeddings to zero
However, in the inference_flow method (for CFG, when guidance_scale > 1.0), the unconditional text embedding is obtained by explicitly encoding an empty string "" via self.encode_text_classifier_free:
# In inference_flow method (inference)
uncond_tokens = [""]
# ...
negative_prompt_embeds = self.text_encoder(
    input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
)[0]  # Encoding ""
# ...
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])  # Concatenating unconditional and conditional embeddings
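For context, the reason I care about which unconditional representation is used is the standard CFG combination at sampling time. The toy example below only illustrates that arithmetic; it is not the repo's exact code, and the shapes and variable names are made up:

```python
import torch

# Toy illustration of the usual classifier-free guidance combination (illustrative only,
# not TangoFlux's exact code). The duplicated batch is [unconditional half, conditional half].
batch, seq_len, dim = 2, 64, 8
model_out = torch.randn(2 * batch, seq_len, dim)
guidance_scale = 3.0

out_uncond, out_cond = model_out.chunk(2)
guided = out_uncond + guidance_scale * (out_cond - out_uncond)
print(guided.shape)  # torch.Size([2, 64, 8])
```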
My understanding is that setting the embeddings to zero (torch.zeros_like) and encoding an empty string ("") will generally produce numerically different unconditional vectors from the text encoder.
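To make the numerical difference concrete, here is a minimal standalone sketch. I am assuming a T5-style text encoder; the checkpoint name is only an example and may not match the one the repo actually loads:

```python
import torch
from transformers import AutoTokenizer, T5EncoderModel

name = "google/flan-t5-large"  # example checkpoint; substitute the encoder the repo actually uses
tokenizer = AutoTokenizer.from_pretrained(name)
text_encoder = T5EncoderModel.from_pretrained(name).eval()

with torch.no_grad():
    # Inference-style "unconditional" input: encode an empty string.
    uncond = tokenizer([""], padding="max_length", max_length=8, return_tensors="pt")
    empty_embeds = text_encoder(
        input_ids=uncond.input_ids, attention_mask=uncond.attention_mask
    )[0]

# Training-style "unconditional" input: a zero tensor of the same shape.
zero_embeds = torch.zeros_like(empty_embeds)

print(bool(empty_embeds.abs().mean() > 0))        # True: encoding "" is not an all-zero vector
print(torch.allclose(empty_embeds, zero_embeds))  # False: the two representations differ
```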
Could you please explain the reasoning behind these two different approaches to the unconditional text input: zeroing the embeddings during training, but encoding "" during inference?
Is there a specific advantage to using the encoded empty string for inference CFG compared to using a zero vector, given that zero vectors are used to simulate unconditionality during training? Would aligning them (e.g., using zero vectors for inference as well) be expected to work, or is the current setup intentional for better performance?
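In case it helps frame the question, the kind of alignment I have in mind would look roughly like this. This is a purely hypothetical sketch (the helper does not exist in the repo), not a proposal that the code should change:

```python
import torch

def make_cfg_embeds_zero_uncond(prompt_embeds: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not in the repo): build the CFG text-embedding batch using a
    zero unconditional embedding, mirroring the training-time masking, instead of
    encoding an empty string."""
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)
    return torch.cat([negative_prompt_embeds, prompt_embeds])

# Example with dummy shapes: (batch, seq_len, hidden_dim)
prompt_embeds = torch.randn(2, 64, 1024)
print(make_cfg_embeds_zero_uncond(prompt_embeds).shape)  # torch.Size([4, 64, 1024])
```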
Thank you for clarifying!
Best regards