# Author: waidhoferj
# Last commit 51f4763: "updated production build to use multiple overlapping samples"
# (copied from a file-viewer page: raw / history / blame links, file size 4.79 kB)
import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF
import torch.nn as nn
class WaveformTrainingPipeline(torch.nn.Module):
    """Training-time waveform pipeline: fixed-length mono audio plus optional noise.

    The input is passed through ``WaveformPreprocessing``. If a noise file was
    given, background noise is then mixed in at a randomly sampled SNR.

    Args:
        expected_duration: Target clip length in seconds.
        snr_mean: Mean (dB) of the normal distribution (std 1.5 dB) from which
            the signal-to-noise ratio is drawn on every ``add_noise`` call.
        noise_path: Optional path to an audio file used as background noise.
        sample_rate: Sample rate (Hz) the pipeline works at. The noise file is
            resampled to it. Defaults to 16000.
    """

    def __init__(
        self,
        expected_duration=6,
        snr_mean=6.0,
        noise_path=None,
        sample_rate=16000,
    ):
        super().__init__()
        self.snr_mean = snr_mean
        # Must be assigned before get_noise(), which resamples to this rate.
        self.sample_rate = sample_rate
        self.noise = self.get_noise(noise_path)
        # int() so that fractional durations still yield a valid sample count.
        self.preprocess_waveform = WaveformPreprocessing(
            int(self.sample_rate * expected_duration)
        )

    def get_noise(self, path) -> torch.Tensor:
        """Load the noise clip as a single channel at ``self.sample_rate``.

        Returns ``None`` when ``path`` is ``None``. Multi-channel files are
        averaged down to one channel, shape ``(1, samples)``.
        """
        if path is None:
            return None
        noise, sr = torchaudio.load(path)
        if noise.shape[0] > 1:
            noise = noise.mean(0, keepdim=True)
        if sr != self.sample_rate:
            noise = taF.resample(noise, sr, self.sample_rate)
        return noise

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        """Mix the stored noise into ``waveform`` at a random SNR.

        The SNR is drawn from N(snr_mean, 1.5) dB and clamped to at least 1 dB.
        The signal is rescaled so that the ratio of its L2 norm to the noise's
        L2 norm matches that SNR. The scaled signal and the noise are then
        averaged.

        Args:
            waveform: Audio with time as the last dimension, e.g.
                ``(channels, samples)`` or ``(batch, channels, samples)``.

        Returns:
            Noisy audio, broadcast to the shape of ``waveform``.

        Raises:
            AssertionError: if no noise file was provided.
        """
        assert (
            self.noise is not None
        ), "Cannot add noise because a noise file was not provided."
        num_samples = waveform.shape[-1]
        # Tile the noise clip until it covers the waveform, then trim it.
        num_repeats = num_samples // self.noise.shape[-1] + 1
        noise = self.noise.repeat(1, num_repeats)[:, :num_samples]
        noise = noise.to(device=waveform.device, dtype=waveform.dtype)
        noise_norm = noise.norm(p=2)
        # Clamp so that silent input does not divide by zero (NaN output).
        signal_norm = waveform.norm(p=2).clamp_min(1e-12)
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
        # SNR in dB is a power ratio; norms are amplitudes, hence /20, base 10.
        snr_amplitude = (10 ** (snr_db / 20)).to(waveform.device)
        scale = snr_amplitude * noise_norm / signal_norm
        return (scale * waveform + noise) / 2

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Preprocess ``waveform`` and add noise when a noise clip is loaded."""
        waveform = self.preprocess_waveform(waveform)
        if self.noise is not None:
            waveform = self.add_noise(waveform)
        return waveform
class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
    """Training pipeline that augments the waveform and then its spectrogram.

    The parent pipeline runs first (fixed-length mono audio, optional noise).
    Its output is converted to a dB-scaled mel spectrogram, and ``mask_count``
    rounds are applied, each a frequency mask followed by a time mask.

    Args:
        freq_mask_size: Maximum frequency-mask width passed to
            ``FrequencyMasking``.
        time_mask_size: Maximum time-mask width passed to ``TimeMasking``.
        mask_count: Number of (frequency, time) masking rounds.
        *args, **kwargs: Forwarded to ``WaveformTrainingPipeline``.
    """

    def __init__(
        self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.mask_count = mask_count
        self.audio_to_spectrogram = AudioToSpectrogram(sample_rate=self.sample_rate)
        # Registration order (freq before time) matches the order of application.
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the augmented spectrogram for ``waveform``."""
        augmented_audio = super().forward(waveform)
        spectrogram = self.audio_to_spectrogram(augmented_audio)
        for _round in range(self.mask_count):
            spectrogram = self.time_mask(self.freq_mask(spectrogram))
        return spectrogram
class SpectrogramProductionPipeline(torch.nn.Module):
    """Inference pipeline: fixed-length mono waveform to dB-scaled mel spectrogram.

    Unlike the training pipelines, this applies no noise and no masking.

    Args:
        sample_rate: Sample rate (Hz) of the incoming audio.
        expected_duration: Target clip length in seconds.
        *args, **kwargs: Forwarded to ``torch.nn.Module``.
    """

    def __init__(self, sample_rate=16000, expected_duration=6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        target_length = sample_rate * expected_duration
        self.preprocess_waveform = WaveformPreprocessing(target_length)
        self.audio_to_spectrogram = AudioToSpectrogram(sample_rate=sample_rate)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Normalise the length and channels of ``waveform``, then compute its spectrogram."""
        fixed_waveform = self.preprocess_waveform(waveform)
        spectrogram = self.audio_to_spectrogram(fixed_waveform)
        return spectrogram
class WaveformPreprocessing(torch.nn.Module):
    """Downmix a waveform to one channel and pad or truncate it to a fixed length.

    Accepts ``(samples,)``, ``(channels, samples)`` or
    ``(batch, channels, samples)`` tensors. Time is always the last dimension.
    The output has a single channel and exactly ``expected_sample_length``
    samples. A 1-D input comes back as ``(1, expected_sample_length)``.

    Args:
        expected_sample_length: Number of samples in the output.
    """

    def __init__(self, expected_sample_length: int):
        super().__init__()
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return ``waveform`` downmixed to one channel and length-normalised."""
        if waveform.dim() == 1:
            # Treat a bare 1-D signal as mono rather than averaging over time.
            waveform = waveform.unsqueeze(0)
        c_dim = 1 if waveform.dim() == 3 else 0
        # Take out extra channels
        if waveform.shape[c_dim] > 1:
            waveform = waveform.mean(c_dim, keepdim=True)
        # ensure it is the correct length
        return self._rectify_duration(waveform, c_dim)

    def _rectify_duration(self, waveform: torch.Tensor, channel_dim: int):
        """Zero-pad on the right or truncate the last (time) dimension.

        ``channel_dim`` is kept for interface compatibility. Time is always
        taken to be the last dimension, whatever the tensor's rank.
        """
        expected_samples = int(self.expected_sample_length)
        sample_count = waveform.shape[-1]
        if sample_count == expected_samples:
            return waveform
        if sample_count < expected_samples:
            # F.pad's pad tuple starts at the LAST dimension: (left, right).
            return torch.nn.functional.pad(
                waveform,
                (0, expected_samples - sample_count),
                mode="constant",
                value=0.0,
            )
        return waveform[..., :expected_samples]
class AudioToSpectrogram:
    """Callable that turns a waveform into a dB-scaled mel spectrogram.

    Uses a 128-band mel filterbank over a 1024-point FFT, followed by
    ``AmplitudeToDB``. This is a plain class, not an ``nn.Module``, so the
    transforms are not registered as submodules of a pipeline that holds it.

    Args:
        sample_rate: Sample rate (Hz) of the incoming audio.
    """

    def __init__(
        self,
        sample_rate=16000,
    ):
        # Note: MelSpectrogram doesn't work on mps right now.
        self.spec = taT.MelSpectrogram(n_fft=1024, n_mels=128, sample_rate=sample_rate)
        self.to_db = taT.AmplitudeToDB()

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the dB-scaled mel spectrogram of ``waveform``."""
        return self.to_db(self.spec(waveform))