1515from PyQt5 .QtWidgets import QComboBox
1616from PyQt5 .QtWidgets import QFileDialog
1717from PyQt5 .QtWidgets import QHBoxLayout
18- from PyQt5 .QtWidgets import QLabel
1918from PyQt5 .QtWidgets import QLineEdit
2019from PyQt5 .QtWidgets import QMainWindow
2120from PyQt5 .QtWidgets import QMessageBox
2221from PyQt5 .QtWidgets import QPushButton
23- from PyQt5 .QtWidgets import QSlider
2422from PyQt5 .QtWidgets import QVBoxLayout
2523from PyQt5 .QtWidgets import QWidget
2624from huggingface_hub import hf_hub_download
@@ -144,9 +142,6 @@ def __init__(self, tts_interface: ToucanTTSInterface):
144142 self .audio_file_path = None
145143 self .result_audio = None
146144 self .min_duration = 1
147- self .slider_val = 100
148- self .durations_are_scaled = False
149- self .prev_slider_val_for_denorm = 100
150145
151146 self .setWindowTitle ("TTS Model Interface" )
152147 self .setGeometry (100 , 100 , 1200 , 900 )
@@ -177,7 +172,7 @@ def __init__(self, tts_interface: ToucanTTSInterface):
177172 # Initialize plots
178173 self .init_plots ()
179174
180- # Initialize slider and buttons
175+ # Initialize buttons
181176 self .init_controls ()
182177
183178 # Initialize Timer for TTS Cooldown
@@ -189,10 +184,6 @@ def __init__(self, tts_interface: ToucanTTSInterface):
189184 def clear_all_widgets (self ):
190185 self .spectrogram_view .setParent (None )
191186 self .pitch_plot .setParent (None )
192- self .upper_row .setParent (None )
193- self .slider_label .setParent (None )
194- self .mod_slider .setParent (None )
195- self .slider_value_label .setParent (None )
196187 self .generate_button .setParent (None )
197188 self .load_audio_button .setParent (None )
198189 self .save_audio_button .setParent (None )
@@ -218,6 +209,7 @@ def load_data(self, durations, pitch, spectrogram):
218209
219210 self .durations = durations
220211 self .cumulative_durations = np .cumsum (self .durations )
212+ self .pitch = pitch
221213 self .spectrogram = spectrogram
222214
223215 # Display Spectrogram
@@ -245,7 +237,7 @@ def load_data(self, durations, pitch, spectrogram):
245237 # Display Durations
246238 self .duration_lines = []
247239 for i , cum_dur in enumerate (self .cumulative_durations ):
248- line = pg .InfiniteLine (pos = cum_dur , angle = 90 , pen = pg .mkPen ('orange' , width = 4 ))
240+ line = pg .InfiniteLine (pos = cum_dur , angle = 90 , pen = pg .mkPen ('orange' , width = 2 ))
249241 self .spectrogram_view .addItem (line )
250242 line .setMovable (True )
251243 # Use lambda with default argument to capture current i
@@ -274,28 +266,6 @@ def init_controls(self):
274266 self .controls_layout = QVBoxLayout ()
275267 self .main_layout .addLayout (self .controls_layout )
276268
277- # Upper row layout for slider
278- self .upper_row = QHBoxLayout ()
279- self .controls_layout .addLayout (self .upper_row )
280-
281- # Slider Label
282- self .slider_label = QLabel ("Faster" )
283- self .upper_row .addWidget (self .slider_label )
284-
285- # Slider
286- self .mod_slider = QSlider (Qt .Horizontal )
287- self .mod_slider .setMinimum (70 )
288- self .mod_slider .setMaximum (130 )
289- self .mod_slider .setValue (self .slider_val )
290- self .mod_slider .setTickPosition (QSlider .TicksBelow )
291- self .mod_slider .setTickInterval (10 )
292- self .mod_slider .valueChanged .connect (self .on_slider_changed )
293- self .upper_row .addWidget (self .mod_slider )
294-
295- # Slider Value Display
296- self .slider_value_label = QLabel ("Slower" )
297- self .upper_row .addWidget (self .slider_value_label )
298-
299269 # Lower row layout for buttons
300270 self .lower_row = QHBoxLayout ()
301271 self .controls_layout .addLayout (self .lower_row )
@@ -406,18 +376,12 @@ def on_user_input_changed(self, text):
406376 # Mark that an update is required
407377 self .mark_tts_update ()
408378
409- def on_slider_changed (self , value ):
410- # Update the slider label
411- # self.slider_value_label.setText(f"Durations at {value}%")
412- self .slider_val = value
413- # print(f"Slider changed to {scaling_factor * 100}% speed")
414- # Mark that an update is required
415- self .mark_tts_update ()
416-
417379 def generate_new_prosody (self ):
418380 """
419381 Generate new prosody.
420382 """
383+ if self .text_input .text ().strip () == "" :
384+ return
421385 wave , mel , durations , pitch = self .tts_backend (text = self .text_input .text (),
422386 view = False ,
423387 duration_scaling_factor = 1.0 ,
@@ -433,9 +397,6 @@ def generate_new_prosody(self):
433397 prosody_creativity = 0.8 ,
434398 return_everything = True )
435399 # reset and clear everything
436- self .slider_val = 100
437- self .prev_slider_val_for_denorm = self .slider_val
438- self .durations_are_scaled = False
439400 self .clear_all_widgets ()
440401 self .init_plots ()
441402 self .init_controls ()
@@ -510,7 +471,8 @@ def save_audio_file(self):
510471
511472 def play_audio (self ):
512473 # print("playing current audio...")
513- sounddevice .play (self .result_audio , samplerate = 24000 )
474+ if self .result_audio is not None :
475+ sounddevice .play (self .result_audio , samplerate = 24000 )
514476
515477 def update_result_audio (self , audio_array ):
516478 """
@@ -525,7 +487,7 @@ def mark_tts_update(self):
525487 Marks that a TTS update is required and starts/resets the timer.
526488 """
527489 self .tts_update_required = True
528- self .tts_timer .start (600 ) # 600 milliseconds
490+ self .tts_timer .start (800 ) # 800 milliseconds delay before the model starts to compute something
529491
530492 def run_tts (self ):
531493 """
@@ -553,16 +515,12 @@ def run_tts(self):
553515 phonemes = self .tts_backend .text2phone .get_phone_string (text = text )
554516 self .phonemes = phonemes .replace (" " , "" )
555517
556- forced_durations = None if self .durations is None or len (self .durations ) != len (self .phonemes ) else insert_zeros_at_indexes (self .durations , self .word_boundaries )
557- if forced_durations is not None and self .durations_are_scaled :
558- forced_durations = torch .LongTensor ([forced_duration / (self .prev_slider_val_for_denorm / 100 ) for forced_duration in forced_durations ]).unsqueeze (0 ) # revert scaling
559- elif forced_durations is not None :
560- forced_durations = torch .LongTensor (forced_durations ).unsqueeze (0 )
518+ forced_durations = None if self .durations is None or len (self .durations ) != len (self .phonemes ) else torch .LongTensor (insert_zeros_at_indexes (self .durations , self .word_boundaries )).unsqueeze (0 )
561519 forced_pitch = None if self .pitch is None or len (self .pitch ) != len (self .phonemes ) else torch .tensor (insert_zeros_at_indexes (self .pitch , self .word_boundaries )).unsqueeze (0 )
562520
563521 wave , mel , durations , pitch = self .tts_backend (text ,
564522 view = False ,
565- duration_scaling_factor = self . slider_val / 100 ,
523+ duration_scaling_factor = 1.0 ,
566524 pitch_variance_scale = 1.0 ,
567525 energy_variance_scale = 1.0 ,
568526 pause_duration_scaling_factor = 1.0 ,
@@ -576,9 +534,6 @@ def run_tts(self):
576534 return_everything = True )
577535
578536 self .word_boundaries = find_zero_indexes (durations )
579- self .prev_slider_val_for_denorm = self .slider_val
580- if self .slider_val != 100 :
581- self .durations_are_scaled = True
582537
583538 self .load_data (durations = durations .cpu ().numpy (), pitch = pitch .cpu ().numpy (), spectrogram = mel .cpu ().transpose (0 , 1 ).numpy ())
584539
@@ -602,7 +557,7 @@ def main():
602557 }
603558
604559 QPushButton {
605- background-color: #808000 ;
560+ background-color: #b9770e ;
606561 border: 1px solid #ffffff;
607562 color: #ffffff;
608563 padding: 8px 16px;
0 commit comments