Spaces:
Runtime error
Runtime error
File size: 3,213 Bytes
0d9af09 47bf442 0d9af09 3806d0c 85adfed 3806d0c 47bf442 0d9af09 a38e25f 47bf442 3806d0c 47bf442 0908871 0d9af09 40f7298 0d9af09 bc7eb76 0d9af09 40f7298 a38e25f 40f7298 47bf442 0d9af09 47bf442 0d9af09 47bf442 0d9af09 0908871 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import torch
from torch.utils.data import Dataset
import pandas as pd
import torchaudio
class VoiceDataset(Dataset):
    """Dataset of audio clips laid out as ``data_directory/<label>/<file>.wav``.

    Each subdirectory name of ``data_directory`` is a class label. Every file
    is loaded, resampled to ``target_sample_rate``, mixed down to mono,
    cut/zero-padded to a fixed length, and run through ``transformation``
    once, eagerly, at construction time — so ``__getitem__`` is a cheap
    in-memory lookup of ``(tensor, label_index)``.
    """

    def __init__(
        self,
        data_directory,
        transformation,
        device,
        target_sample_rate=48000,
        time_limit_in_secs=5,
    ):
        """
        Args:
            data_directory: root folder whose subfolders are class labels.
            transformation: callable applied to each preprocessed waveform
                (e.g. a torchaudio spectrogram transform).
            device: torch device the waveforms are moved to.
            target_sample_rate: sample rate (Hz) all audio is resampled to.
            time_limit_in_secs: fixed clip length in seconds; every clip is
                cut or zero-padded to
                ``time_limit_in_secs * target_sample_rate`` samples.
        """
        # file processing
        self._data_path = data_directory
        # Only subdirectories are labels; ignore stray files at the root,
        # which would otherwise crash os.listdir(label_path) later.
        self._labels = [
            entry
            for entry in os.listdir(self._data_path)
            if os.path.isdir(os.path.join(self._data_path, entry))
        ]
        self.label_mapping = {label: i for i, label in enumerate(self._labels)}
        self.audio_files_labels = self._join_audio_files()
        self.device = device
        # audio processing
        self.transformation = transformation
        self.target_sample_rate = target_sample_rate
        self.num_samples = time_limit_in_secs * self.target_sample_rate
        # preprocess all wavs eagerly
        self.wavs = self._process_wavs()

    def __len__(self):
        """Return the number of (file, label) pairs in the dataset."""
        return len(self.audio_files_labels)

    def __getitem__(self, index):
        """Return the preprocessed ``(waveform, label_index)`` pair."""
        return self.wavs[index]

    def _process_wavs(self):
        """Load and preprocess every audio file into (tensor, label_index)."""
        wavs = []
        for file, label in self.audio_files_labels:
            filepath = os.path.join(self._data_path, label, file)
            # load wav
            wav, sr = torchaudio.load(filepath, normalize=True)
            # Resample on CPU *before* moving to the target device: the
            # Resample module created in _resample lives on CPU, so applying
            # it to a CUDA tensor would raise a device-mismatch error.
            wav = self._resample(wav, sr)
            wav = wav.to(self.device)
            wav = self._mix_down(wav)
            wav = self._cut_or_pad(wav)
            # apply transformation
            wav = self.transformation(wav)
            wavs.append((wav, self.label_mapping[label]))
        return wavs

    def _join_audio_files(self):
        """Join all the audio file names and labels into one single-dimensional list."""
        audio_files_labels = []
        for label in self._labels:
            label_path = os.path.join(self._data_path, label)
            for f in os.listdir(label_path):
                audio_files_labels.append((f, label))
        return audio_files_labels

    def _resample(self, wav, current_sample_rate):
        """Resample audio to the target sample rate, if necessary."""
        if current_sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                current_sample_rate, self.target_sample_rate
            )
            wav = resampler(wav)
        return wav

    def _mix_down(self, wav):
        """Mix down multi-channel audio to a single channel, if necessary."""
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        return wav

    def _cut_or_pad(self, wav):
        """Modify audio if number of samples != target number of samples of the dataset.

        If there are too many samples, cut the audio.
        If there are not enough samples, pad the audio with zeros.
        """
        length_signal = wav.shape[1]
        if length_signal > self.num_samples:
            wav = wav[:, :self.num_samples]
        elif length_signal < self.num_samples:
            num_of_missing_samples = self.num_samples - length_signal
            pad = (0, num_of_missing_samples)
            wav = torch.nn.functional.pad(wav, pad)
        return wav
|