waidhoferj's picture
updated paths to work remotely
b6800ef
raw
history blame
3.95 kB
import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF
import torch.nn as nn
class AudioTrainingPipeline(torch.nn.Module):
def __init__(self,
input_freq=16000,
resample_freq=16000,
expected_duration=6,
freq_mask_size=10,
time_mask_size=80,
mask_count = 2,
snr_mean=6.0,
noise_path=None):
super().__init__()
self.input_freq = input_freq
self.snr_mean = snr_mean
self.mask_count = mask_count
self.noise = self.get_noise(noise_path)
self.resample = taT.Resample(input_freq,resample_freq)
self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
self.audio_to_spectrogram = AudioToSpectrogram(
sample_rate=resample_freq,
)
self.freq_mask = taT.FrequencyMasking(freq_mask_size)
self.time_mask = taT.TimeMasking(time_mask_size)
def get_noise(self, path) -> torch.Tensor:
if path is None:
return None
noise, sr = torchaudio.load(path)
if noise.shape[0] > 1:
noise = noise.mean(0, keepdim=True)
if sr != self.input_freq:
noise = taF.resample(noise,sr, self.input_freq)
return noise
def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
assert self.noise is not None, "Cannot add noise because a noise file was not provided."
num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
noise_power = noise.norm(p=2)
signal_power = waveform.norm(p=2)
snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
snr = torch.exp(snr_db / 10)
scale = snr * noise_power / signal_power
noisy_waveform = (scale * waveform + noise) / 2
return noisy_waveform
def forward(self, waveform:torch.Tensor) -> torch.Tensor:
try:
waveform = self.resample(waveform)
except:
print("oops")
waveform = self.preprocess_waveform(waveform)
if self.noise is not None:
waveform = self.add_noise(waveform)
spec = self.audio_to_spectrogram(waveform)
# Spectrogram augmentation
for _ in range(self.mask_count):
spec = self.freq_mask(spec)
spec = self.time_mask(spec)
return spec
class WaveformPreprocessing(torch.nn.Module):
def __init__(self, expected_sample_length:int):
super().__init__()
self.expected_sample_length = expected_sample_length
def forward(self, waveform:torch.Tensor) -> torch.Tensor:
# Take out extra channels
if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
# ensure it is the correct length
waveform = self._rectify_duration(waveform)
return waveform
def _rectify_duration(self,waveform:torch.Tensor):
expected_samples = self.expected_sample_length
sample_count = waveform.shape[1]
if expected_samples == sample_count:
return waveform
elif expected_samples > sample_count:
pad_amount = expected_samples - sample_count
return torch.nn.functional.pad(waveform, (0, pad_amount),mode="constant", value=0.0)
else:
return waveform[:,:expected_samples]
class AudioToSpectrogram(torch.nn.Module):
def __init__(
self,
sample_rate=16000,
):
super().__init__()
self.spec = taT.MelSpectrogram(sample_rate=sample_rate, n_mels=128, n_fft=1024) # TODO: Change mels to 64
self.to_db = taT.AmplitudeToDB()
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
spectrogram = self.spec(waveform)
spectrogram = self.to_db(spectrogram)
return spectrogram