Skip to content

Commit 36dfc13

Browse files
fix gpu bugs, add test cases
1 parent c68bb43 commit 36dfc13

File tree

11 files changed

+100
-59
lines changed

11 files changed

+100
-59
lines changed

README.md

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -47,47 +47,39 @@ streamlit run test/streamlit.py
4747
**Important:** When you run the above command for the first time, the web page may stay blank for several minutes while the models are downloaded. You can check the terminal for the download progress.
4848

4949

50-
### Python interface
50+
### Example
5151

52-
Basic examples:
52+
- *test/test.py*:
5353

5454
```python
55-
# Will automatically download model parameters.
56-
from voicefixer import VoiceFixer
57-
from voicefixer import Vocoder
55+
...
5856

59-
# Initialize model
57+
# TEST VOICEFIXER
58+
## Initialize a voicefixer
6059
voicefixer = VoiceFixer()
61-
# Speech restoration
62-
63-
# Mode 0: Original Model (suggested by default)
64-
voicefixer.restore(input="", # input wav file path
65-
output="", # output wav file path
66-
cuda=False, # whether to use gpu acceleration
67-
mode = 0) # You can try out mode 0, 1, 2 to find out the best result
68-
# Mode 1: Add preprocessing module (remove higher frequency)
69-
voicefixer.restore(input="", # input wav file path
70-
output="", # output wav file path
71-
cuda=False, # whether to use gpu acceleration
72-
mode = 1) # You can try out mode 0, 1, 2 to find out the best result
73-
# Mode 2: Train mode (might work sometimes on seriously degraded real speech)
74-
voicefixer.restore(input="", # input wav file path
75-
output="", # output wav file path
76-
cuda=False, # whether to use gpu acceleration
77-
mode = 2) # You can try out mode 0, 1, 2 to find out the best result
78-
79-
# Another similar function
80-
# voicefixer.restore_inmem()
81-
82-
# Universal speaker independent vocoder
83-
vocoder = Vocoder(sample_rate=44100) # Only 44100 sampling rate is supported.
84-
85-
# Convert mel spectrogram to waveform
86-
wave = vocoder.forward(mel=mel_spec) # This forward function is used in the following oracle function.
87-
88-
# Test vocoder using the mel spectrogram of 'fpath', save output to file out_path
89-
vocoder.oracle(fpath="", # input wav file path
90-
out_path="") # output wav file path
60+
## Mode 0: Original Model (suggested by default)
61+
## Mode 1: Add preprocessing module (remove higher frequency)
62+
## Mode 2: Train mode (might work sometimes on seriously degraded real speech)
63+
for mode in [0,1,2]:
64+
voicefixer.restore(input=os.path.join(git_root,"test/utterance/original/original.flac"), # low quality .wav/.flac file
65+
output=os.path.join(git_root,"test/utterance/output/output_mode_"+str(mode)+".flac"), # save file path
66+
cuda=False, # GPU acceleration
67+
mode=mode)
68+
69+
70+
# TEST VOCODER
71+
## Initialize a vocoder. Only 44100 sampling rate is supported.
72+
vocoder = Vocoder(sample_rate=44100)
73+
74+
### read wave (fpath) -> mel spectrogram -> vocoder -> wave -> save wave (out_path)
75+
vocoder.oracle(fpath=os.path.join(git_root,"test/utterance/original/original.flac"),
76+
out_path=os.path.join(git_root,"test/utterance/output/oracle.flac"),
77+
cuda=False) # GPU acceleration
78+
79+
# Other interfaces
80+
# voicefixer.restore_inmem
81+
# vocoder.forward
82+
...
9183
```
9284

9385
### Other Features

test/test.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,53 @@
1010
9/14/21 11:02 AM Haohe Liu 1.0 None
1111
'''
1212

13-
from voicefixer import VoiceFixer
14-
13+
import git
14+
import os
15+
import librosa
16+
import numpy as np
17+
from voicefixer import VoiceFixer, Vocoder
18+
19+
git_root = git.Repo("", search_parent_directories=True).git.rev_parse("--show-toplevel")
20+
os.makedirs(os.path.join(git_root,"test/utterance/output"),exist_ok=True)
21+
22+
def check(fname):
23+
"""
24+
check if the output is normal
25+
"""
26+
output = os.path.join(git_root,"test/utterance/output",fname)
27+
target = os.path.join(git_root, "test/utterance/target", fname)
28+
output, _ = librosa.load(output,sr=44100)
29+
target, _ = librosa.load(target, sr=44100)
30+
assert np.mean(np.abs(output-target)) < 0.01
31+
32+
# TEST VOICEFIXER
33+
## Initialize a voicefixer
1534
voicefixer = VoiceFixer()
35+
# Mode 0: Original Model (suggested by default)
36+
# Mode 1: Add preprocessing module (remove higher frequency)
37+
# Mode 2: Train mode (might work sometimes on seriously degraded real speech)
38+
for mode in [0,1,2]:
39+
voicefixer.restore(input=os.path.join(git_root,"test/utterance/original/original.flac"), # low quality .wav/.flac file
40+
output=os.path.join(git_root,"test/utterance/output/output_mode_"+str(mode)+".flac"), # save file path
41+
cuda=False, # GPU acceleration
42+
mode=mode)
43+
if(mode != 2):
44+
check("output_mode_"+str(mode)+".flac")
45+
46+
47+
# TEST VOCODER
48+
## Initialize a vocoder
49+
vocoder = Vocoder(sample_rate=44100)
1650

17-
voicefixer.restore(input="/Users/liuhaohe/Downloads/vocals.wav",
18-
output="/Users/liuhaohe/Downloads/vocals_mode_0.wav",
19-
cuda=False,mode=0)
51+
### read wave (fpath) -> mel spectrogram -> vocoder -> wave -> save wave (out_path)
52+
vocoder.oracle(fpath=os.path.join(git_root,"test/utterance/original/original.flac"),
53+
out_path=os.path.join(git_root,"test/utterance/output/oracle.flac"),
54+
cuda=False) # GPU acceleration
2055

56+
# Another interface
57+
# vocoder.forward(mel=mel)
2158

22-
voicefixer.restore(input="/Users/liuhaohe/Downloads/vocals.wav",
23-
output="/Users/liuhaohe/Downloads/vocals_mode_1.wav",
24-
cuda=False,mode=1)
59+
check("oracle.flac")
2560

61+
print("Pass")
2662

27-
voicefixer.restore(input="/Users/liuhaohe/Downloads/vocals.wav",
28-
output="/Users/liuhaohe/Downloads/vocals_mode_2.wav",
29-
cuda=False,mode=2)
124 KB
Binary file not shown.

test/utterance/target/oracle.flac

150 KB
Binary file not shown.
134 KB
Binary file not shown.
124 KB
Binary file not shown.
143 KB
Binary file not shown.

voicefixer/base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,23 +76,20 @@ def remove_higher_frequency(self, wav, ratio=0.95):
7676
while (i < energy_level.shape[0] and curent_level < threshold):
7777
curent_level += energy_level[i + 1, ...]
7878
i += 1
79-
print(i)
8079
spec[i:, ...] = np.zeros_like(spec[i:, ...])
8180
stft = spec * cos + 1j * spec * sin
8281
return librosa.istft(stft)
8382

8483
@torch.no_grad()
8584
def restore_inmem(self, wav_10k, cuda=False, mode=0, your_vocoder_func=None):
86-
if(cuda and torch.cuda.is_available()):
87-
self._model = self._model.cuda()
88-
# metrics = {}
85+
check_cuda_availability(cuda=cuda)
86+
try_tensor_cuda(self._model,cuda=cuda)
8987
if(mode == 0):
9088
self._model.eval()
9189
elif(mode == 1):
9290
self._model.eval()
9391
elif(mode == 2):
9492
self._model.train() # More effective on seriously demaged speech
95-
9693
res = []
9794
seg_length = 44100*30
9895
break_point = seg_length

voicefixer/tools/pytorch_util.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22
import torch.nn as nn
33
import numpy as np
44

5+
6+
def check_cuda_availability(cuda):
7+
if(cuda and not torch.cuda.is_available()):
8+
raise RuntimeError("Error: You set cuda=True but no cuda device found.")
9+
10+
def try_tensor_cuda(tensor, cuda):
11+
if(cuda and torch.cuda.is_available()):
12+
return tensor.cuda()
13+
else:
14+
return tensor.cpu()
15+
516
def to_log(input):
617
assert torch.sum(input < 0) == 0, str(input)+" has negative values counts "+str(torch.sum(input < 0))
718
return torch.log10(torch.clip(input, min=1e-8))

voicefixer/vocoder/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,31 @@ def _load_pretrain(self, pth):
3232
# wav_re = self.model(mel) # torch.Size([1, 1, 104076])
3333
# save_wave(tensor2numpy(wav_re)*2**15,save_dir,sample_rate=sample_rate)
3434

35-
def forward(self,mel):
35+
def forward(self,mel, cuda=False):
3636
"""
3737
:param non normalized mel spectrogram: [batchsize, 1, t-steps, n_mel]
3838
:return: [batchsize, 1, samples]
3939
"""
4040
assert mel.size()[-1] == 128
41+
check_cuda_availability(cuda=cuda)
42+
try_tensor_cuda(self.model,cuda=cuda)
43+
try_tensor_cuda(mel,cuda=cuda)
4144
self.weight_torch = self.weight_torch.type_as(mel)
4245
mel = mel / self.weight_torch
4346
mel = tr_normalize(tr_amp_to_db(torch.abs(mel)) - 20.0)
4447
mel = tr_pre(mel[:,0,...])
4548
wav_re = self.model(mel)
4649
return wav_re
4750

48-
def oracle(self, fpath, out_path):
51+
def oracle(self, fpath, out_path, cuda=False):
52+
check_cuda_availability(cuda=cuda)
53+
try_tensor_cuda(self.model,cuda=cuda)
4954
wav = read_wave(fpath, sample_rate=self.rate)[..., 0]
5055
wav = wav/np.max(np.abs(wav))
5156
stft = np.abs(librosa.stft(wav,hop_length=Config.hop_length,win_length=Config.win_size,n_fft=Config.n_fft))
5257
mel = linear_to_mel(stft)
5358
mel = normalize(amp_to_db(np.abs(mel)) - 20)
54-
mel = pre(np.transpose(mel, (1, 0)))
59+
mel = pre(np.transpose(mel, (1, 0)),cuda=cuda)
5560
with torch.no_grad():
5661
wav_re = self.model(mel)
5762
save_wave(tensor2numpy(wav_re*2**15), out_path, sample_rate=self.rate)

0 commit comments

Comments
 (0)