# Author: waidhoferj
# Commit: 557fb53 — "Refactor config style and reorganize files"
# File size: 4.37 kB
import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF
import torch.nn as nn
class WaveformTrainingPipeline(torch.nn.Module):
    """Training-time waveform pipeline: resample, downmix/crop/pad, add noise.

    Args:
        input_freq: Sample rate of waveforms passed to ``forward``.
        resample_freq: Sample rate the waveform is resampled to (and the rate
            the noise clip is converted to, since noise is mixed in afterwards).
        expected_duration: Target clip length in seconds; the output has
            ``round(resample_freq * expected_duration)`` samples.
        snr_mean: Mean of the normal distribution (std 1.5) from which the
            per-call signal-to-noise ratio in dB is drawn.
        noise_path: Optional audio file used as background noise. When
            ``None``, no noise is added.
    """

    def __init__(
        self,
        input_freq=16000,
        resample_freq=16000,
        expected_duration=6,
        snr_mean=6.0,
        noise_path=None,
    ):
        super().__init__()
        self.input_freq = input_freq
        self.snr_mean = snr_mean
        # Must be set before get_noise(), which resamples the noise to this rate.
        self.resample_frequency = resample_freq
        # Non-persistent buffer: follows .to(device)/.cuda() with the module but
        # stays out of state_dict, like the plain attribute it replaces.
        self.register_buffer("noise", self.get_noise(noise_path), persistent=False)
        self.resample = taT.Resample(input_freq, resample_freq)
        # round() keeps the sample count an int even for fractional durations.
        self.preprocess_waveform = WaveformPreprocessing(
            round(resample_freq * expected_duration)
        )

    def get_noise(self, path) -> "torch.Tensor | None":
        """Load a noise clip as mono audio at ``self.resample_frequency``.

        Returns ``None`` when ``path`` is ``None``.
        """
        if path is None:
            return None
        noise, sr = torchaudio.load(path)
        if noise.shape[0] > 1:
            # Downmix multi-channel noise to a single channel.
            noise = noise.mean(0, keepdim=True)
        # forward() adds noise *after* resampling the waveform, so the noise
        # must be at the resampled rate, not at input_freq.
        if sr != self.resample_frequency:
            noise = taF.resample(noise, sr, self.resample_frequency)
        return noise

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        """Mix the stored noise clip into ``waveform`` at a random SNR.

        The SNR (dB) is drawn from N(snr_mean, 1.5) and clamped to >= 1 dB.
        The *signal* is rescaled so that its energy relative to the noise
        energy equals ``10 ** (snr_db / 10)``. The mixture is then halved.

        Args:
            waveform: Tensor indexed ``[channel, sample]``; the noise
                (single channel) is broadcast over the channel axis.

        Returns:
            The noisy waveform, same number of samples as the input.

        Raises:
            AssertionError: If no noise file was provided.
        """
        assert (
            self.noise is not None
        ), "Cannot add noise because a noise file was not provided."
        num_samples = waveform.shape[1]
        # Tile the noise clip until it covers the waveform, then crop.
        num_repeats = num_samples // self.noise.shape[1] + 1
        noise = self.noise.repeat(1, num_repeats)[:, :num_samples].to(waveform.device)
        noise_norm = noise.norm(p=2)
        # A silent (e.g. fully zero-padded) clip would otherwise divide by zero
        # and turn the whole output into NaN/inf.
        signal_norm = waveform.norm(p=2).clamp_min(1e-8)
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
        # L2 norms are amplitudes, so the dB -> ratio conversion is
        # 10 ** (dB / 20); its square is the intended power ratio 10 ** (dB / 10).
        amplitude_ratio = (10.0 ** (snr_db / 20)).to(waveform.device)
        scale = amplitude_ratio * noise_norm / signal_norm
        return (scale * waveform + noise) / 2

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Resample, force mono/fixed length, then add noise if configured."""
        waveform = self.resample(waveform)
        waveform = self.preprocess_waveform(waveform)
        if self.noise is not None:
            waveform = self.add_noise(waveform)
        return waveform
class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
    """Waveform training pipeline followed by a spectrogram and random masking.

    Positional arguments beyond the first three, and all extra keyword
    arguments, are forwarded unchanged to ``WaveformTrainingPipeline``.

    Args:
        freq_mask_size: Max width passed to ``FrequencyMasking``.
        time_mask_size: Max width passed to ``TimeMasking``.
        mask_count: Number of (frequency mask, time mask) rounds applied.
    """

    def __init__(
        self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.mask_count = mask_count
        # The spectrogram uses the post-resampling rate set by the parent.
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=self.resample_frequency,
        )
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Run the waveform pipeline, convert to a spectrogram, then mask it."""
        spectrogram = self.audio_to_spectrogram(super().forward(waveform))
        # Each round applies a frequency mask and then a time mask, in that order.
        maskers = (self.freq_mask, self.time_mask)
        for _round in range(self.mask_count):
            for masker in maskers:
                spectrogram = masker(spectrogram)
        return spectrogram
class WaveformPreprocessing(torch.nn.Module):
    """Downmix a waveform to one channel and force a fixed sample count.

    The input is indexed ``[channel, sample]``: dimension 0 is averaged when it
    holds more than one channel, and dimension 1 is padded or truncated.
    """

    def __init__(self, expected_sample_length: int):
        super().__init__()
        # Number of samples every output waveform will have.
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return a single-channel waveform of ``expected_sample_length`` samples."""
        channel_count = waveform.shape[0]
        mono = waveform.mean(0, keepdim=True) if channel_count > 1 else waveform
        return self._rectify_duration(mono)

    def _rectify_duration(self, waveform: torch.Tensor):
        """Zero-pad on the right or truncate so dim 1 matches the target length."""
        target = self.expected_sample_length
        shortfall = target - waveform.shape[1]
        if shortfall > 0:
            # Too short: append silence at the end.
            return torch.nn.functional.pad(
                waveform, (0, shortfall), mode="constant", value=0.0
            )
        if shortfall < 0:
            # Too long: keep only the leading samples.
            return waveform[:, :target]
        # Already the right length: hand back the same tensor.
        return waveform
class AudioToSpectrogram(torch.nn.Module):
    """Convert a waveform into a globally normalized dB-scale mel spectrogram.

    This is an ``nn.Module`` so that, when assigned as an attribute of another
    module, its MelSpectrogram transform (and that transform's tensors) is
    registered as a submodule and moves with ``.to(device)``. Instances remain
    directly callable: ``spec = AudioToSpectrogram()(waveform)``.

    Args:
        sample_rate: Sample rate of the incoming waveform.
        n_mels: Number of mel bins (default matches the original hard-coded 128).
        n_fft: FFT size (default matches the original hard-coded 1024).
    """

    def __init__(
        self,
        sample_rate=16000,
        n_mels=128,
        n_fft=1024,
    ):
        super().__init__()
        # NOTE: the original author reports this transform didn't work on the
        # MPS backend at the time of writing.
        self.spec = taT.MelSpectrogram(
            sample_rate=sample_rate, n_mels=n_mels, n_fft=n_fft
        )
        self.to_db = taT.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return ``(S - mean(S)) / (2 * std(S))`` for the dB mel spectrogram S.

        Mean and std are taken over the whole tensor (all dims at once).
        """
        spectrogram = self.to_db(self.spec(waveform))
        # A constant spectrogram (e.g. from silent input) has std 0; clamping
        # yields zeros instead of NaN while leaving normal inputs unchanged.
        std = spectrogram.std().clamp_min(1e-6)
        return (spectrogram - spectrogram.mean()) / (2 * std)