End detection updates and Text context remapping during training #14569
base: magpietts_2508
Conversation
Signed-off-by: Paarth Neekhara <[email protected]>
Signed-off-by: paarthneekhara <[email protected]>
Regarding EOS detection: it's interesting that detecting from all codebooks can lead to early termination -- good find! About the option of detecting from codebook 0 only: there is the issue of what happens if an EOS appears only in a codebook other than 0 and we ignore it (for EOS detection purposes). We would then end up replacing the token with token ID 0 before decoding with the codec, but we know 0 doesn't necessarily correspond to codec silence. But maybe we need to consider these two mechanisms separately:
If (2) works better, that'll give (1) more flexibility.
Cool addition of the text remapping logic, by the way!
That's a good point. We should figure out a better way to clean up the codes. Are you suggesting that if we want to clamp a particular codebook to be within range, we should make the whole frame silent, or just that codebook?
I think it probably requires experimentation, e.g. by listening to find out which of the options we're considering sounds better:
One way to choose would be to load the codes of a real speech signal, randomly choose a few positions (frame+codebook combinations) that we corrupt with the 3 methods above, listen to all 3, and choose the one that sounds best.
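The listening experiment could be scripted roughly like this. This is a sketch, not the repo's code: the three candidate repair methods, `SILENCE_CODES`, and `MAX_TOKEN_ID` are all hypothetical stand-ins, and the actual codec's silence tokens and codebook size would need to be substituted.

```python
import numpy as np

# Assumed constants, NOT the actual codec's values.
SILENCE_CODES = np.zeros(8, dtype=np.int64)   # hypothetical per-codebook "silence" tokens
MAX_TOKEN_ID = 1023                           # hypothetical codebook size - 1

def silence_whole_frame(codes, frame, codebook):
    # Candidate fix 1: silence every codebook in the offending frame.
    codes[:, frame] = SILENCE_CODES

def silence_one_codebook(codes, frame, codebook):
    # Candidate fix 2: silence only the offending codebook entry.
    codes[codebook, frame] = SILENCE_CODES[codebook]

def clamp_one_codebook(codes, frame, codebook):
    # Candidate fix 3: clamp the offending entry into the valid range.
    codes[codebook, frame] = np.clip(codes[codebook, frame], 0, MAX_TOKEN_ID)

def corrupt_for_listening_test(codes, method, n_positions=5, seed=0):
    """Corrupt a few random (frame, codebook) positions of real speech codes,
    apply one candidate repair, and return the result for decoding/listening."""
    rng = np.random.default_rng(seed)
    out = codes.copy()
    n_codebooks, n_frames = out.shape
    for _ in range(n_positions):
        frame = int(rng.integers(n_frames))
        codebook = int(rng.integers(n_codebooks))
        out[codebook, frame] = MAX_TOKEN_ID + 1  # simulate an out-of-range code
        method(out, frame, codebook)             # then apply the candidate fix
    return out
```

Decoding the three variants of the same utterance through the codec and comparing by ear would then pick the repair strategy.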
An alternative that came to mind: if we detect an EOS that we want to ignore, e.g. because it's not in codebook 0, and we need to replace it with something: resample. It wouldn't require another forward pass and should give something plausible (possibly better than silence). But we'd have to make sure not to get stuck in an infinite resampling loop (we could limit how many times we resample, then accept the EOS if it insists...). Edit: On second thought, the resampling option only works for parallel prediction. If we have a local transformer we'd have to do something else: either run the LT again (which is higher complexity than resampling), replace with silence, or maybe just resample that particular codebook from the parallel head.
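A minimal sketch of the resample-with-limit idea for the parallel-prediction case; the function name, tensor shapes, and retry cap are assumptions, not the PR's actual API:

```python
import torch

def sample_frame_ignoring_eos(logits, eos_id, max_resamples=5):
    """Sample one frame of codes from per-codebook logits. If a codebook other
    than 0 draws EOS, resample that codebook up to max_resamples times, then
    accept the EOS if it insists. Sketch for parallel prediction only.

    logits: (num_codebooks, vocab_size) tensor of unnormalized logits.
    """
    probs = torch.softmax(logits, dim=-1)
    codes = torch.multinomial(probs, num_samples=1).squeeze(-1)  # (num_codebooks,)
    for cb in range(1, probs.shape[0]):  # codebook 0's EOS is honored as-is
        tries = 0
        while codes[cb] == eos_id and tries < max_resamples:
            codes[cb] = torch.multinomial(probs[cb], num_samples=1)
            tries += 1
    return codes
```

Since only the offending codebook is redrawn from the already-computed distribution, no extra forward pass is needed, which is the appeal over rerunning a local transformer.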
Some minor style changes
@@ -1997,7 +2013,7 @@ def get_inference_attention_plots(
         return cross_attention_maps, headwise_cross_attention_maps

-    def find_eos_frame_index(self, codes) -> Optional[int]:
+    def find_eos_frame_index(self, codes, eos_detection_method) -> Optional[int]:
Since you changed the default return of this function, can you update the typehint and docstring to more accurately match the updated function?
-        return None
+        return float('inf')

     def detect_eos(self, audio_codes_multinomial, audio_codes_argmax, eos_detection_method):
Can you add typehints and a docstring to this function?
We should add a docstring to infer_batch, since there are now many arguments in this function.
We should also update the infer_batch calls in magpietts_preference_optimization.py to use the right arguments for the end detection logic during GRPO/DPO data generation.
This PR introduces two main updates.
Added an option to disable finished/unfinished sentence tracking when the prior is applied. I sometimes notice that forcing or disallowing EOS prediction introduces artifacts towards the end, unless we carefully tune the unfinished- and finished-sentence constants. We don't really need to handhold whether the sentence is finished or not and can trust the model's decision, even when we apply the prior. So I have added an argument to neither force nor disallow EOS prediction, enabled by setting ignore_finished_sentence_tracking=True.
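A minimal sketch of what gating the tracking behind such a flag could look like. The function name, the "finished" boost constant, and the argument names here are hypothetical, not the PR's actual code:

```python
import torch

def apply_eos_constraint(logits, sentence_finished, eos_id,
                         ignore_finished_sentence_tracking=False):
    """Sketch: when tracking is ignored, the model's EOS logit is left
    untouched; otherwise EOS is disallowed for unfinished sentences and
    boosted for finished ones (all names/constants are assumptions)."""
    if ignore_finished_sentence_tracking:
        return logits  # trust the model's own EOS decision
    logits = logits.clone()
    if sentence_finished:
        logits[..., eos_id] += 5.0  # hypothetical finished-sentence boost
    else:
        logits[..., eos_id] = float('-inf')  # disallow early EOS
    return logits
```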
EOS detection - Until now, the logic was that if any codebook in multinomial or argmax sampling predicts an EOS token, we end generation. VERY RARELY, I notice this leads to predicting EOS abruptly. At some point, our logic was to end only if the argmax of codebook 0 is EOS. I have added a few options: predict EOS if any, all, or the zeroth codebook is EOS, and whether to look only at argmax sampling, or at either argmax or multinomial sampling. I'm keeping this customizable because I suspect the parallel-prediction EOS logic might differ from the MaskGIT EOS prediction logic.
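As a rough illustration of how the configurable check could be composed (a standalone sketch: the option-name format, the explicit eos_id parameter, and the tensor shapes are assumptions, not the PR's actual detect_eos signature):

```python
import torch

def detect_eos(audio_codes_multinomial, audio_codes_argmax, eos_id,
               eos_detection_method="argmax_any"):
    """Sketch of a configurable EOS check for one frame of codes.

    Both code tensors have shape (num_codebooks,). The method string combines
    which sampling stream is inspected ("argmax" only, or "either" of argmax
    and multinomial) with which codebooks must emit EOS ("any", "all", or
    "zero" for codebook 0 only). Option names here are illustrative.
    """
    sampling, _, scope = eos_detection_method.partition("_")

    def check(codes):
        is_eos = codes == eos_id
        if scope == "any":
            return bool(is_eos.any())
        if scope == "all":
            return bool(is_eos.all())
        if scope == "zero":
            return bool(is_eos[0])
        raise ValueError(f"unknown scope: {scope}")

    if sampling == "argmax":
        return check(audio_codes_argmax)
    if sampling == "either":
        return check(audio_codes_argmax) or check(audio_codes_multinomial)
    raise ValueError(f"unknown sampling: {sampling}")
```

Under this scheme, "argmax_zero" reproduces the old codebook-0 behavior, while "either_any" reproduces the current any-codebook behavior that can terminate early.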