|
import torch |
|
import torch.nn as nn |
|
from torchaudio import transforms as T |
|
|
|
|
|
class PadCrop(nn.Module):
    """Pad or crop the last dimension of a signal to exactly ``n_samples``.

    Signals longer than ``n_samples`` are cropped to a window of that length;
    shorter signals are copied to the start of a zero-filled output, i.e.
    they are right-padded with zeros.

    Args:
        n_samples: Length of the last dimension of the returned tensor.
        randomize: If True, the crop window start is drawn uniformly with
            ``torch.randint`` from ``[0, max(0, length - n_samples)]``.
            If False, the window always starts at index 0.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    # NOTE(review): this overrides nn.Module.__call__ directly instead of
    # defining forward(); kept that way so existing callers see no change.
    def __call__(self, signal):
        """Return ``signal`` padded/cropped along its last dimension.

        Args:
            signal: Tensor with at least one dimension; all leading
                dimensions are preserved, only the last one is resized.

        Returns:
            A new tensor created with ``signal.new_zeros`` (same dtype and
            device as ``signal``) of shape ``(*signal.shape[:-1], n_samples)``.
        """
        length = signal.shape[-1]

        if self.randomize:
            # The upper bound is exclusive, hence the +1. When the signal is
            # not longer than n_samples the only possible start is 0.
            max_start = max(0, length - self.n_samples)
            start = torch.randint(0, max_start + 1, []).item()
        else:
            start = 0
        end = start + self.n_samples

        output = signal.new_zeros([*signal.shape[:-1], self.n_samples])
        # Slicing past the end of the signal simply yields fewer samples, so
        # the destination width is limited to what is actually available.
        output[..., :min(length, self.n_samples)] = signal[..., start:end]
        return output
|
|
|
|
|
def set_audio_channels(audio, target_channels):
    """Convert ``audio`` to mono or stereo along dimension 1.

    Args:
        audio: Tensor whose dimension 1 is treated as the channel axis.
            The stereo branches index it with three subscripts, so a
            ``(batch, channels, samples)`` layout is presumably expected.
        target_channels: ``1`` averages dimension 1 into a single channel;
            ``2`` duplicates a single channel or keeps the first two of
            several. Any other value returns ``audio`` unchanged.

    Returns:
        The converted tensor, or ``audio`` itself when nothing needs to change.
    """
    if target_channels == 1:
        # Downmix: average over the channel axis but keep that axis.
        return audio.mean(1, keepdim=True)

    if target_channels != 2:
        return audio

    channel_count = audio.shape[1]
    if channel_count == 1:
        # Mono -> stereo: copy the single channel into both outputs.
        return audio.repeat(1, 2, 1)
    if channel_count > 2:
        # More than two channels -> stereo: keep only the first two.
        return audio[:, :2, :]
    return audio
|
|
|
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move, resample, length-normalise and channel-convert an audio tensor.

    Steps, in order: move ``audio`` to ``device``; resample from ``in_sr`` to
    ``target_sr`` when they differ; pad/crop the last dimension to
    ``target_length`` (deterministically, starting at index 0); add leading
    dimensions so 1-D and 2-D inputs become 3-D; finally convert the channel
    count with ``set_audio_channels``.

    Args:
        audio: Tensor whose last dimension is the sample axis.
        in_sr: Sample rate of ``audio``.
        target_sr: Desired sample rate.
        target_length: Desired number of samples in the last dimension.
        target_channels: Channel count passed to ``set_audio_channels``.
        device: Device the tensor and the resampler are moved to.

    Returns:
        The processed tensor. 1-D and 2-D inputs come back 3-D; inputs of
        higher rank keep their rank.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)

    # Hand PadCrop a 2-D (rows, samples) tensor whatever the input rank is,
    # then restore the leading dimensions. Without this a 1-D or batched
    # input fails inside the pad/crop step before the rank handling below
    # is ever reached.
    leading_shape = audio.shape[:-1]
    needs_fold = audio.dim() != 2
    if needs_fold:
        audio = audio.reshape(-1, audio.shape[-1])

    audio = PadCrop(target_length, randomize=False)(audio)

    if needs_fold:
        audio = audio.reshape(*leading_shape, target_length)

    # Add the missing leading dimensions so the result is at least 3-D.
    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)

    audio = set_audio_channels(audio, target_channels)

    return audio