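"""Audio preprocessing pipelines for waveform and spectrogram models.

Defines training pipelines (noise injection at a randomized SNR, plus
SpecAugment-style frequency/time masking) and a production pipeline that
converts fixed-length waveforms to log-mel spectrograms.
"""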
from typing import Optional

import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF


class WaveformTrainingPipeline(torch.nn.Module):
    """Prepares raw waveforms for training, optionally mixing in a noise file."""

    def __init__(
        self,
        expected_duration=6,
        snr_mean=6.0,
        noise_path=None,
    ):
        super().__init__()
        self.snr_mean = snr_mean
        # sample_rate must be set before get_noise, which resamples to it.
        self.sample_rate = 16000
        self.noise = self.get_noise(noise_path)
        self.preprocess_waveform = WaveformPreprocessing(
            self.sample_rate * expected_duration
        )

    def get_noise(self, path) -> Optional[torch.Tensor]:
        """Load a noise waveform, downmix it to mono, and resample it to the pipeline rate."""
        if path is None:
            return None
        noise, sr = torchaudio.load(path)
        if noise.shape[0] > 1:
            noise = noise.mean(0, keepdim=True)
        if sr != self.sample_rate:
            noise = taF.resample(noise, sr, self.sample_rate)
        return noise

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        """Mix the stored noise into the waveform at a randomly drawn SNR."""
        assert (
            self.noise is not None
        ), "Cannot add noise because a noise file was not provided."
        # Tile the noise so it covers the full waveform, then trim to length.
        num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
        noise = self.noise.repeat(1, num_repeats)[:, : waveform.shape[1]]
        noise_power = noise.norm(p=2)
        signal_power = waveform.norm(p=2)
        # Draw a per-example SNR in dB around snr_mean, floored at 1 dB.
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
        # Convert dB to an amplitude ratio (the norms are amplitudes, hence /20).
        snr = 10 ** (snr_db / 20)
        scale = snr * noise_power / signal_power
        noisy_waveform = (scale * waveform + noise) / 2
        return noisy_waveform

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
waveform = self.preprocess_waveform(waveform)
if self.noise is not None:
waveform = self.add_noise(waveform)
return waveform


class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
    """Extends the waveform pipeline to produce masked mel spectrograms for training."""

    def __init__(
        self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.mask_count = mask_count
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=self.sample_rate,
        )
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
waveform = super().forward(waveform)
spec = self.audio_to_spectrogram(waveform)
        # SpecAugment-style augmentation: random frequency and time masks
for _ in range(self.mask_count):
spec = self.freq_mask(spec)
spec = self.time_mask(spec)
return spec


class SpectrogramProductionPipeline(torch.nn.Module):
    """Inference-time pipeline: preprocesses the waveform and computes a mel spectrogram, with no augmentation."""

    def __init__(self, sample_rate=16000, expected_duration=6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.preprocess_waveform = WaveformPreprocessing(
            sample_rate * expected_duration
        )
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=sample_rate,
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
waveform = self.preprocess_waveform(waveform)
return self.audio_to_spectrogram(waveform)


class WaveformPreprocessing(torch.nn.Module):
    """Downmixes audio to mono and pads or truncates it to a fixed sample length."""

    def __init__(self, expected_sample_length: int):
        super().__init__()
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Channel dim is 1 for batched (3D) input, 0 for unbatched (2D) input.
        c_dim = 1 if len(waveform.shape) == 3 else 0
        # Downmix multi-channel audio to mono
        if waveform.shape[c_dim] > 1:
            waveform = waveform.mean(c_dim, keepdim=True)
        # Pad or truncate to the expected length
        waveform = self._rectify_duration(waveform, c_dim)
        return waveform

    def _rectify_duration(self, waveform: torch.Tensor, channel_dim: int):
        expected_samples = self.expected_sample_length
        sample_count = waveform.shape[channel_dim + 1]
        if expected_samples == sample_count:
            return waveform
        elif expected_samples > sample_count:
            # Zero-pad the time (last) dimension on the right.
            pad_amount = expected_samples - sample_count
            return torch.nn.functional.pad(
                waveform,
                [0, pad_amount],
                mode="constant",
                value=0.0,
            )
        else:
            # Truncate the time (last) dimension, regardless of input rank.
            return waveform[..., :expected_samples]


class AudioToSpectrogram:
    """Converts a waveform to a log-scaled (dB) mel spectrogram."""

    def __init__(
        self,
        sample_rate=16000,
    ):
        self.spec = taT.MelSpectrogram(
            sample_rate=sample_rate, n_mels=128, n_fft=1024
        )  # Note: MelSpectrogram doesn't work on mps right now.
        self.to_db = taT.AmplitudeToDB()

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
spectrogram = self.spec(waveform)
spectrogram = self.to_db(spectrogram)
return spectrogram
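

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original file):
    # runs a synthetic 1-second mono clip through the training and production
    # pipelines. noise_path is omitted, so no noise is injected.
    waveform = torch.randn(1, 16000)  # mono, 1 s at 16 kHz

    train_pipeline = SpectrogramTrainingPipeline(expected_duration=6)
    train_spec = train_pipeline(waveform)
    print("training spectrogram:", tuple(train_spec.shape))

    prod_pipeline = SpectrogramProductionPipeline(sample_rate=16000, expected_duration=6)
    prod_spec = prod_pipeline(waveform)
    print("production spectrogram:", tuple(prod_spec.shape))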