@@ -61,19 +61,41 @@ def _pre(self, model, input, cuda):
6161 # return models.to_log(sp), models.to_log(mel_orig)
6262 return sp , mel_orig
6363
64+ def remove_higher_frequency (self , wav , ratio = 0.95 ):
65+ stft = librosa .stft (wav )
66+ real , img = np .real (stft ), np .imag (stft )
67+ mag = (real ** 2 + img ** 2 ) ** 0.5
68+ cos , sin = real / mag , img / mag
69+ spec = np .abs (stft ) # [1025,T]
70+ feature = spec .copy ()
71+ feature = np .log10 (feature )
72+ feature [feature < 0 ] = 0
73+ energy_level = np .sum (feature , axis = 1 )
74+ threshold = np .sum (energy_level ) * ratio
75+ curent_level , i = energy_level [0 ], 0
76+ while (i < energy_level .shape [0 ] and curent_level < threshold ):
77+ curent_level += energy_level [i + 1 , ...]
78+ i += 1
79+ spec [i :, ...] = np .zeros_like (spec [i :, ...])
80+ stft = spec * cos + 1j * spec * sin
81+ return librosa .istft (stft )
82+
6483 def restore (self , input , output , cuda = False , mode = 0 ):
6584 if (cuda and torch .cuda .is_available ()):
6685 self ._model = self ._model .cuda ()
6786 # metrics = {}
68- if (mode == 1 ):
69- self ._model .train () # More effective on seriously demaged speech
70- elif (mode == 2 ):
71- self ._model .generator .denoiser .train () # Another option worth trying
72- else :
87+ if (mode == 0 ):
7388 self ._model .eval ()
89+ elif (mode == 1 ):
90+ self ._model .eval ()
91+ elif (mode == 2 ):
92+ self ._model .train () # More effective on seriously demaged speech
7493
7594 with torch .no_grad ():
7695 wav_10k = self ._load_wav (input , sample_rate = 44100 )
96+ if (mode == 0 ):
97+ # print("In mode 0, we will remove part of the higher frequency part before processing")
98+ wav_10k = self .remove_higher_frequency (wav_10k )
7799 res = []
78100 seg_length = 44100 * 60
79101 break_point = seg_length
0 commit comments