|
import os |
|
import random |
|
from multiprocessing import Manager |
|
from multiprocessing import Process |
|
|
|
import librosa |
|
import numpy |
|
import soundfile as sf |
|
import torch |
|
import torchaudio |
|
from torch.utils.data import Dataset |
|
from tqdm import tqdm |
|
|
|
from Preprocessing.AudioPreprocessor import AudioPreprocessor |
|
|
|
|
|
def random_pitch_shifter(x, sample_rate=24000):
    """
    Shift the pitch of a waveform by a randomly chosen number of semitones.

    Args:
        x: waveform tensor of shape (..., time).
        sample_rate: sampling rate of x in Hz (default 24000, matching the
            dataset's desired_samplingrate).

    Returns:
        the pitch-shifted waveform with the same shape as x.
    """
    # asymmetric choice set: downward shifts are sampled more often than upward ones
    n_steps = random.choice([-12, -9, -6, 3, 12])
    return torchaudio.transforms.PitchShift(sample_rate=sample_rate, n_steps=n_steps)(x)
|
|
|
|
|
def polarity_inverter(x):
    """Invert the polarity of the waveform (flip the sign of every sample)."""
    return -x
|
|
|
|
|
class HiFiGANDataset(Dataset):
    """
    In-memory dataset of raw waveforms for vocoder training.

    Yields pairs of (high-resolution audio segment, low-resolution mel
    spectrogram), where the spectrogram mimics what the TTS would predict.
    """

    def __init__(self,
                 list_of_paths,
                 desired_samplingrate=24000,
                 samples_per_segment=12288,
                 loading_processes=max(os.cpu_count() - 2, 1),
                 use_random_corruption=False):
        """
        Args:
            list_of_paths: audio file paths to load into the cache.
            desired_samplingrate: target sampling rate of the high-res audio.
            samples_per_segment: length (in samples) of each training segment.
            loading_processes: number of processes used to fill the cache.
            use_random_corruption: whether to apply random wave augmentations
                and distortions during __getitem__.
        """
        self.use_random_corruption = use_random_corruption
        self.samples_per_segment = samples_per_segment
        self.desired_samplingrate = desired_samplingrate
        # preprocessor used to downsample to 16kHz and to compute mel spectrograms
        self.melspec_ap = AudioPreprocessor(input_sr=self.desired_samplingrate,
                                            output_sr=16000,
                                            cut_silence=False)

        if loading_processes == 1:
            self.waves = list()
            self.cache_builder_process(list_of_paths)
        else:
            resource_manager = Manager()
            self.waves = resource_manager.list()

            # split the path list into roughly equal chunks, one per process
            path_splits = list()
            process_list = list()
            for i in range(loading_processes):
                path_splits.append(list_of_paths[i * len(list_of_paths) // loading_processes:(i + 1) * len(
                    list_of_paths) // loading_processes])
            for path_split in path_splits:
                process_list.append(Process(target=self.cache_builder_process, args=(path_split,), daemon=True))
                process_list[-1].start()
            for process in process_list:
                process.join()
            # materialize the Manager proxy into a plain list so that every
            # __getitem__ / __len__ does not pay IPC cost to the manager process
            self.waves = list(self.waves)

        self.wave_augs = [random_pitch_shifter, polarity_inverter, lambda x: x, lambda x: x, lambda x: x, lambda x: x]
        self.wave_distortions = [CodecSimulator(), lambda x: x, lambda x: x, lambda x: x, lambda x: x]
        print("{} eligible audios found".format(len(self.waves)))

    def cache_builder_process(self, path_split):
        """Load each file, convert to mono, resample and append it to the cache."""
        for path in tqdm(path_split):
            try:
                wave, sr = sf.read(path)
                if len(wave.shape) == 2:
                    # multi-channel audio: mix down to mono
                    wave = librosa.to_mono(numpy.transpose(wave))
                if sr != self.desired_samplingrate:
                    wave = librosa.resample(y=wave, orig_sr=sr, target_sr=self.desired_samplingrate)

                self.waves.append(wave)
            except RuntimeError:
                # soundfile raises RuntimeError subclasses on unreadable files
                print(f"Problem with the following path: {path}")

    def __getitem__(self, index):
        """
        load the audio from the path and clean it.
        All audio segments have to be cut to the same length,
        according to the NeurIPS reference implementation.

        return a pair of high-res audio and corresponding low-res spectrogram as if it was predicted by the TTS
        """
        try:
            wave = self.waves[index]
            # pad short clips by repeating them (with a little silence in
            # between) until a segment can be cut out of them
            while len(wave) < self.samples_per_segment + 50:

                wave = numpy.concatenate([wave, numpy.zeros(shape=1000), wave])

            wave = torch.Tensor(wave)

            if self.use_random_corruption:
                # augmentation on the full wave (identity in most draws)
                wave = random.choice(self.wave_augs)(wave.unsqueeze(0)).squeeze(0)

            # cut a random fixed-length segment out of the wave
            max_audio_start = len(wave) - self.samples_per_segment
            audio_start = random.randint(0, max_audio_start)
            segment = wave[audio_start: audio_start + self.samples_per_segment]

            # downsample to 16kHz to mimic the TTS-side representation
            resampled_segment = self.melspec_ap.resample(segment).float()
            if self.use_random_corruption:
                # distortion on the low-res wave (identity in most draws)
                resampled_segment = random.choice(self.wave_distortions)(resampled_segment.unsqueeze(0)).squeeze(0)
            # drop the final frame so segment length and spec length line up
            melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment,
                                                               explicit_sampling_rate=16000,
                                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
            return segment.detach(), melspec.detach()
        except RuntimeError:
            # fall back to a neighboring sample rather than crashing training
            print("encountered a runtime error, using fallback strategy")
            if index == 0:
                index = len(self.waves) - 1
            return self.__getitem__(index - 1)

    def __len__(self):
        return len(self.waves)
|
|
|
|
|
class CodecSimulator(torch.nn.Module):
    """
    Simulates lossy codec artifacts by round-tripping audio through
    mu-law companding with a coarse quantization.
    """

    def __init__(self, quantization_channels=64):
        """
        Args:
            quantization_channels: number of mu-law quantization levels;
                fewer channels means stronger distortion (default 64).
        """
        super().__init__()
        self.encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=quantization_channels)
        self.decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=quantization_channels)

    def forward(self, x):
        """Encode then immediately decode x, returning the degraded waveform."""
        return self.decoder(self.encoder(x))
|
|
|
|
|
if __name__ == '__main__':
    # Visual sanity checks for the augmentation / distortion components.
    import matplotlib.pyplot as plt

    wav, sr = sf.read("../../audios/speaker_references/female_high_voice.wav")
    resampled_wave = torch.Tensor(librosa.resample(y=wav, orig_sr=sr, target_sr=24000))
    melspec_ap = AudioPreprocessor(input_sr=24000,
                                   output_sr=16000,
                                   cut_silence=False)

    # reference spectrogram (final frame dropped, matching the dataset's behavior)
    spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(resampled_wave).float(),
                                               explicit_sampling_rate=16000,
                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)

    cs = CodecSimulator()
    masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True)

    # 1) codec simulator: overlay original and mu-law round-tripped waveform
    out = cs(resampled_wave.unsqueeze(0)).squeeze(0)

    plt.plot(resampled_wave, alpha=0.5)
    plt.plot(out, alpha=0.5)
    plt.title("Codec Simulator")
    plt.show()

    # 2) frequency masking applied to the spectrogram
    for _ in range(5):
        masked_spec = masker(spec.unsqueeze(0)).squeeze(0)
        print(masked_spec)
        plt.imshow(masked_spec.cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Masked Spec")
        plt.show()

    # 3) random pitch shifting, visualized as spectrograms
    for _ in range(5):
        shifted_wave = random_pitch_shifter(resampled_wave.unsqueeze(0)).squeeze(0)
        shifted_spec = melspec_ap.audio_to_mel_spec_tensor(melspec_ap.resample(shifted_wave).float(),
                                                           explicit_sampling_rate=16000,
                                                           normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
        plt.imshow(shifted_spec.detach().cpu().numpy(), origin="lower", cmap='GnBu')
        plt.title("Pitch Shifted Spec")
        plt.show()
|
|