import os

import torch
from torch.utils.data import Dataset
import pandas as pd
import torchaudio

class VoiceDataset(Dataset):
    """Audio dataset that loads labelled clips from a directory tree.

    Expects one subdirectory per label, each containing audio files.
    Every clip is resampled to a common rate, mixed down to mono,
    cut or padded to a fixed length, and transformed (e.g. into a
    spectrogram) once, at construction time.
    """

    def __init__(
            self,
            data_directory,
            transformation,
            device,
            target_sample_rate=48000,
            time_limit_in_secs=5,
        ):
        # file processing
        self._data_path = data_directory
        # one subdirectory per label; sorted so the label -> index mapping
        # is deterministic across runs
        self._labels = sorted(
            d for d in os.listdir(self._data_path)
            if os.path.isdir(os.path.join(self._data_path, d))
        )
        self.label_mapping = {label: i for i, label in enumerate(self._labels)}
        self.audio_files_labels = self._join_audio_files()

        self.device = device

        # audio processing
        self.transformation = transformation
        self.target_sample_rate = target_sample_rate
        self.num_samples = time_limit_in_secs * self.target_sample_rate

        # preprocess all wavs eagerly so __getitem__ is a cheap lookup
        # (trades memory for speed; best suited to datasets that fit in memory)
        self.wavs = self._process_wavs()

    def __len__(self):
        return len(self.audio_files_labels)

    def __getitem__(self, index):
        return self.wavs[index]

    def _process_wavs(self):
        wavs = []
        for file, label in self.audio_files_labels:
            filepath = os.path.join(self._data_path, label, file)

            # load wav
            wav, sr = torchaudio.load(filepath, normalize=True)

            # modify wav file, if necessary
            wav = wav.to(self.device)
            wav = self._resample(wav, sr)
            wav = self._mix_down(wav)
            wav = self._cut_or_pad(wav)
            
            # apply transformation
            wav = self.transformation(wav)

            wavs.append((wav, self.label_mapping[label]))
        
        return wavs

    def _join_audio_files(self):
        """Join all the audio file names and labels into one single dimenional array"""
        audio_files_labels = []

        for label in self._labels:
            label_path = os.path.join(self._data_path, label)
            for f in os.listdir(label_path):
                audio_files_labels.append((f, label))

        return audio_files_labels

    def _resample(self, wav, current_sample_rate):
        """Resample audio to the target sample rate, if necessary"""
        if current_sample_rate != self.target_sample_rate:
            # build the resampler on the same device as the waveform so its
            # convolution kernel and the input tensor don't end up on
            # different devices
            resampler = torchaudio.transforms.Resample(
                current_sample_rate, self.target_sample_rate
            ).to(wav.device)
            wav = resampler(wav)

        return wav

    def _mix_down(self, wav):
        """Mix down audio to a single channel, if necessary"""
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)

        return wav

    def _cut_or_pad(self, wav):
        """Modify audio if number of samples != target number of samples of the dataset.

        If there are too many samples, cut the audio.
        If there are not enough samples, pad the audio with zeros.
        """

        length_signal = wav.shape[1]
        if length_signal > self.num_samples:
            wav = wav[:, :self.num_samples]
        elif length_signal < self.num_samples:
            num_of_missing_samples = self.num_samples - length_signal
            pad = (0, num_of_missing_samples)
            wav = torch.nn.functional.pad(wav, pad)

        return wav
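
# Usage sketch (illustrative, not part of the original file): builds the
# dataset with a mel-spectrogram transformation and iterates one batch.
# The "data" directory, spectrogram parameters, and batch size below are
# assumptions; adapt them to your layout.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # transform waveforms into mel spectrograms on the same device the
    # dataset moves audio to
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=48000,
        n_fft=1024,
        hop_length=512,
        n_mels=64,
    ).to(device)

    dataset = VoiceDataset("data", mel_spectrogram, device)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    spectrograms, labels = next(iter(loader))
    print(spectrograms.shape, labels.shape)  # (16, 1, 64, T), (16,)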