Skip to content

Commit dc0c86b

Browse files
authored
Merge pull request #672 from lhotse-speech/feature/faster-speed-perturbation
~20x faster speed perturbation
2 parents: ea9014e + 5e27e2c; commit: dc0c86b

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

lhotse/augmentation/torchaudio.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -114,16 +114,10 @@ class Speed(AudioTransform):
114114
factor: float
115115

116116
def __call__(self, samples: np.ndarray, sampling_rate: int) -> np.ndarray:
117-
check_torchaudio_version()
118-
import torchaudio
119-
120-
sampling_rate = int(sampling_rate) # paranoia mode
121-
effect = [["speed", str(self.factor)], ["rate", str(sampling_rate)]]
122-
if isinstance(samples, np.ndarray):
123-
samples = torch.from_numpy(samples)
124-
augmented, new_sampling_rate = torchaudio.sox_effects.apply_effects_tensor(
125-
samples, sampling_rate, effect
117+
resampler = get_or_create_resampler(
118+
round(sampling_rate * self.factor), sampling_rate
126119
)
120+
augmented = resampler(torch.from_numpy(samples))
127121
return augmented.numpy()
128122

129123
def reverse_timestamps(

test/augmentation/test_torchaudio.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def test_reverb_normalize_output(audio, rir, normalize_output, early_only):
127127
def test_speed(audio):
128128
speed = Speed(factor=1.1)
129129
perturbed = speed(audio, SAMPLING_RATE)
130-
assert perturbed.shape == (1, 14545)
130+
assert perturbed.shape == (1, 14546)
131131

132132

133133
@pytest.mark.parametrize("scale", [0.125, 1.0, 2.0])
@@ -143,7 +143,7 @@ def test_deserialize_transform_speed(audio):
143143
speed = AudioTransform.from_dict({"name": "Speed", "kwargs": {"factor": 1.1}})
144144
perturbed_speed = speed(audio, SAMPLING_RATE)
145145

146-
assert perturbed_speed.shape == (1, 14545)
146+
assert perturbed_speed.shape == (1, 14546)
147147

148148

149149
def test_deserialize_transform_volume(audio):
@@ -160,7 +160,7 @@ def test_serialize_deserialize_transform_speed(audio):
160160
speed = AudioTransform.from_dict(data_speed)
161161
perturbed_speed = speed(audio, SAMPLING_RATE)
162162

163-
assert perturbed_speed.shape == (1, 14545)
163+
assert perturbed_speed.shape == (1, 14546)
164164

165165

166166
def test_serialize_deserialize_transform_volume(audio):

0 commit comments

Comments
 (0)