Skip to content

Commit e18c227

Browse files
committed
simplify advanced demo
1 parent 335a8ba commit e18c227

File tree

1 file changed

+11
-56
lines changed

1 file changed

+11
-56
lines changed

run_advanced_GUI_demo.py

Lines changed: 11 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
from PyQt5.QtWidgets import QComboBox
1616
from PyQt5.QtWidgets import QFileDialog
1717
from PyQt5.QtWidgets import QHBoxLayout
18-
from PyQt5.QtWidgets import QLabel
1918
from PyQt5.QtWidgets import QLineEdit
2019
from PyQt5.QtWidgets import QMainWindow
2120
from PyQt5.QtWidgets import QMessageBox
2221
from PyQt5.QtWidgets import QPushButton
23-
from PyQt5.QtWidgets import QSlider
2422
from PyQt5.QtWidgets import QVBoxLayout
2523
from PyQt5.QtWidgets import QWidget
2624
from huggingface_hub import hf_hub_download
@@ -144,9 +142,6 @@ def __init__(self, tts_interface: ToucanTTSInterface):
144142
self.audio_file_path = None
145143
self.result_audio = None
146144
self.min_duration = 1
147-
self.slider_val = 100
148-
self.durations_are_scaled = False
149-
self.prev_slider_val_for_denorm = 100
150145

151146
self.setWindowTitle("TTS Model Interface")
152147
self.setGeometry(100, 100, 1200, 900)
@@ -177,7 +172,7 @@ def __init__(self, tts_interface: ToucanTTSInterface):
177172
# Initialize plots
178173
self.init_plots()
179174

180-
# Initialize slider and buttons
175+
# Initialize buttons
181176
self.init_controls()
182177

183178
# Initialize Timer for TTS Cooldown
@@ -189,10 +184,6 @@ def __init__(self, tts_interface: ToucanTTSInterface):
189184
def clear_all_widgets(self):
190185
self.spectrogram_view.setParent(None)
191186
self.pitch_plot.setParent(None)
192-
self.upper_row.setParent(None)
193-
self.slider_label.setParent(None)
194-
self.mod_slider.setParent(None)
195-
self.slider_value_label.setParent(None)
196187
self.generate_button.setParent(None)
197188
self.load_audio_button.setParent(None)
198189
self.save_audio_button.setParent(None)
@@ -218,6 +209,7 @@ def load_data(self, durations, pitch, spectrogram):
218209

219210
self.durations = durations
220211
self.cumulative_durations = np.cumsum(self.durations)
212+
self.pitch = pitch
221213
self.spectrogram = spectrogram
222214

223215
# Display Spectrogram
@@ -245,7 +237,7 @@ def load_data(self, durations, pitch, spectrogram):
245237
# Display Durations
246238
self.duration_lines = []
247239
for i, cum_dur in enumerate(self.cumulative_durations):
248-
line = pg.InfiniteLine(pos=cum_dur, angle=90, pen=pg.mkPen('orange', width=4))
240+
line = pg.InfiniteLine(pos=cum_dur, angle=90, pen=pg.mkPen('orange', width=2))
249241
self.spectrogram_view.addItem(line)
250242
line.setMovable(True)
251243
# Use lambda with default argument to capture current i
@@ -274,28 +266,6 @@ def init_controls(self):
274266
self.controls_layout = QVBoxLayout()
275267
self.main_layout.addLayout(self.controls_layout)
276268

277-
# Upper row layout for slider
278-
self.upper_row = QHBoxLayout()
279-
self.controls_layout.addLayout(self.upper_row)
280-
281-
# Slider Label
282-
self.slider_label = QLabel("Faster")
283-
self.upper_row.addWidget(self.slider_label)
284-
285-
# Slider
286-
self.mod_slider = QSlider(Qt.Horizontal)
287-
self.mod_slider.setMinimum(70)
288-
self.mod_slider.setMaximum(130)
289-
self.mod_slider.setValue(self.slider_val)
290-
self.mod_slider.setTickPosition(QSlider.TicksBelow)
291-
self.mod_slider.setTickInterval(10)
292-
self.mod_slider.valueChanged.connect(self.on_slider_changed)
293-
self.upper_row.addWidget(self.mod_slider)
294-
295-
# Slider Value Display
296-
self.slider_value_label = QLabel("Slower")
297-
self.upper_row.addWidget(self.slider_value_label)
298-
299269
# Lower row layout for buttons
300270
self.lower_row = QHBoxLayout()
301271
self.controls_layout.addLayout(self.lower_row)
@@ -406,18 +376,12 @@ def on_user_input_changed(self, text):
406376
# Mark that an update is required
407377
self.mark_tts_update()
408378

409-
def on_slider_changed(self, value):
410-
# Update the slider label
411-
# self.slider_value_label.setText(f"Durations at {value}%")
412-
self.slider_val = value
413-
# print(f"Slider changed to {scaling_factor * 100}% speed")
414-
# Mark that an update is required
415-
self.mark_tts_update()
416-
417379
def generate_new_prosody(self):
418380
"""
419381
Generate new prosody.
420382
"""
383+
if self.text_input.text().strip() == "":
384+
return
421385
wave, mel, durations, pitch = self.tts_backend(text=self.text_input.text(),
422386
view=False,
423387
duration_scaling_factor=1.0,
@@ -433,9 +397,6 @@ def generate_new_prosody(self):
433397
prosody_creativity=0.8,
434398
return_everything=True)
435399
# reset and clear everything
436-
self.slider_val = 100
437-
self.prev_slider_val_for_denorm = self.slider_val
438-
self.durations_are_scaled = False
439400
self.clear_all_widgets()
440401
self.init_plots()
441402
self.init_controls()
@@ -510,7 +471,8 @@ def save_audio_file(self):
510471

511472
def play_audio(self):
512473
# print("playing current audio...")
513-
sounddevice.play(self.result_audio, samplerate=24000)
474+
if self.result_audio is not None:
475+
sounddevice.play(self.result_audio, samplerate=24000)
514476

515477
def update_result_audio(self, audio_array):
516478
"""
@@ -525,7 +487,7 @@ def mark_tts_update(self):
525487
Marks that a TTS update is required and starts/resets the timer.
526488
"""
527489
self.tts_update_required = True
528-
self.tts_timer.start(600) # 600 milliseconds
490+
self.tts_timer.start(800) # 800 milliseconds delay before the model starts to compute something
529491

530492
def run_tts(self):
531493
"""
@@ -553,16 +515,12 @@ def run_tts(self):
553515
phonemes = self.tts_backend.text2phone.get_phone_string(text=text)
554516
self.phonemes = phonemes.replace(" ", "")
555517

556-
forced_durations = None if self.durations is None or len(self.durations) != len(self.phonemes) else insert_zeros_at_indexes(self.durations, self.word_boundaries)
557-
if forced_durations is not None and self.durations_are_scaled:
558-
forced_durations = torch.LongTensor([forced_duration / (self.prev_slider_val_for_denorm / 100) for forced_duration in forced_durations]).unsqueeze(0) # revert scaling
559-
elif forced_durations is not None:
560-
forced_durations = torch.LongTensor(forced_durations).unsqueeze(0)
518+
forced_durations = None if self.durations is None or len(self.durations) != len(self.phonemes) else torch.LongTensor(insert_zeros_at_indexes(self.durations, self.word_boundaries)).unsqueeze(0)
561519
forced_pitch = None if self.pitch is None or len(self.pitch) != len(self.phonemes) else torch.tensor(insert_zeros_at_indexes(self.pitch, self.word_boundaries)).unsqueeze(0)
562520

563521
wave, mel, durations, pitch = self.tts_backend(text,
564522
view=False,
565-
duration_scaling_factor=self.slider_val / 100,
523+
duration_scaling_factor=1.0,
566524
pitch_variance_scale=1.0,
567525
energy_variance_scale=1.0,
568526
pause_duration_scaling_factor=1.0,
@@ -576,9 +534,6 @@ def run_tts(self):
576534
return_everything=True)
577535

578536
self.word_boundaries = find_zero_indexes(durations)
579-
self.prev_slider_val_for_denorm = self.slider_val
580-
if self.slider_val != 100:
581-
self.durations_are_scaled = True
582537

583538
self.load_data(durations=durations.cpu().numpy(), pitch=pitch.cpu().numpy(), spectrogram=mel.cpu().transpose(0, 1).numpy())
584539

@@ -602,7 +557,7 @@ def main():
602557
}
603558
604559
QPushButton {
605-
background-color: #808000;
560+
background-color: #b9770e;
606561
border: 1px solid #ffffff;
607562
color: #ffffff;
608563
padding: 8px 16px;

0 commit comments

Comments
 (0)