diff --git a/dac/model/base.py b/dac/model/base.py
index 546b3cb..a6c93f9 100644
--- a/dac/model/base.py
+++ b/dac/model/base.py
@@ -204,8 +204,8 @@ def compress(
         range_fn = range if not verbose else tqdm.trange
 
         for i in range_fn(0, nt, hop):
-            x = audio_signal[..., i : i + n_samples]
-            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+            x = audio_signal[..., i : i + n_samples + self.hop_length]  # one extra hop of overlap for the crossfade
+            x = x.zero_pad(0, max(0, n_samples - x.shape[-1] + self.hop_length))
 
             audio_data = x.audio_data.to(self.device)
             audio_data = self.preprocess(audio_data, self.sample_rate)
@@ -271,8 +271,11 @@ def decompress(
             r = self.decode(z)
             recons.append(r.to(original_device))
 
-        recons = torch.cat(recons, dim=-1)
-        recons = AudioSignal(recons, self.sample_rate)
+        # Stitch the decoded chunks together with a crossfade over the overlap region
+        chunks = recons[0]
+        for i in range(1, len(recons)):
+            chunks = self.crossfade_concat(chunks, recons[i], self.hop_length)
+        recons = AudioSignal(chunks, self.sample_rate)
 
         resample_fn = recons.resample
         loudness_fn = recons.loudness
@@ -292,3 +295,12 @@ def decompress(
 
         self.padding = original_padding
         return recons
+
+    @torch.no_grad()
+    def crossfade_concat(self, chunk1, chunk2, overlap):
+        # Complementary raised-cosine fades: fade_out + fade_in == 1, so overall gain stays constant
+        fade_out = (torch.cos(torch.linspace(0, torch.pi / 2, overlap)) ** 2).to(chunk1)
+        fade_in = (torch.cos(torch.linspace(torch.pi / 2, 0, overlap)) ** 2).to(chunk1)
+        # Blend chunk1's tail into chunk2's head in place, then concatenate
+        chunk2[..., :overlap] = chunk1[..., -overlap:] * fade_out + chunk2[..., :overlap] * fade_in
+        return torch.cat((chunk1[..., :-overlap], chunk2), dim=-1)
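
Note on the crossfade math (a minimal standalone sketch, not part of the patch): the
two cos^2 ramps are complementary, so fade_out(t) + fade_in(t) == 1 at every sample
and the overlap region is blended at constant overall gain. The snippet below
reproduces the helper outside the class to check that property and to demonstrate the
stitching on dummy tensors; the (batch, channels, samples) shape and the overlap of
512 are illustrative assumptions, not values taken from the patch.

    import torch

    def crossfade_concat(chunk1, chunk2, overlap):
        # Same blending logic as the patched method, reproduced standalone
        fade_out = (torch.cos(torch.linspace(0, torch.pi / 2, overlap)) ** 2).to(chunk1)
        fade_in = (torch.cos(torch.linspace(torch.pi / 2, 0, overlap)) ** 2).to(chunk1)
        chunk2 = chunk2.clone()  # avoid mutating the caller's tensor in this sketch
        chunk2[..., :overlap] = chunk1[..., -overlap:] * fade_out + chunk2[..., :overlap] * fade_in
        return torch.cat((chunk1[..., :-overlap], chunk2), dim=-1)

    overlap = 512  # assumed hop_length, for illustration only

    # Complementary cos^2 ramps sum to 1 everywhere
    fade_out = torch.cos(torch.linspace(0, torch.pi / 2, overlap)) ** 2
    fade_in = torch.cos(torch.linspace(torch.pi / 2, 0, overlap)) ** 2
    assert torch.allclose(fade_out + fade_in, torch.ones(overlap))

    # Two overlapping all-ones chunks stitch into an all-ones signal of the expected length
    a = torch.ones(1, 1, 4096)
    b = torch.ones(1, 1, 4096)
    out = crossfade_concat(a, b, overlap)
    assert out.shape == (1, 1, 4096 + 4096 - overlap)
    assert torch.allclose(out, torch.ones_like(out))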