Spaces:
Paused
Paused
import json | |
import random | |
import torch | |
import torchaudio | |
from torch.utils.data import Dataset | |
class AudioTextDataset(Dataset):
    """Dataset that samples fixed-length mono clips from audio-text databases.

    Each datafile is a JSON file whose top-level ``'data'`` list holds entries
    with a ``'wav'`` field (path of the audio file) and a ``'caption'`` field
    (the paired text).

    Params:
        datafiles: iterable of paths to the JSON datafiles
        sampling_rate: audio sampling rate
        max_clip_len: max length (seconds) of audio clip to be sampled
    """

    # Clips shorter than this many seconds are dropped and replaced.
    _MIN_CLIP_SECONDS = 1
    # Upper bound on replacement draws, so a dataset made only of unreadable
    # entries fails with a clear error instead of retrying without limit.
    _MAX_READ_ATTEMPTS = 100

    def __init__(
        self,
        datafiles=[''],
        sampling_rate=32000,
        max_clip_len=5,
    ):
        super().__init__()
        all_data_json = []
        for datafile in datafiles:
            with open(datafile, 'r', encoding='utf-8') as fp:
                all_data_json.extend(json.load(fp)['data'])

        self.all_data_json = all_data_json
        self.sampling_rate = sampling_rate
        # Sample counts are used as tensor sizes and slice bounds, so they
        # must be integers even when max_clip_len is fractional (e.g. 2.5 s).
        self.max_length = int(round(max_clip_len * sampling_rate))

    def __len__(self):
        """Return the number of audio-text entries."""
        return len(self.all_data_json)

    def _cut_or_randomcrop(self, waveform):
        """Return `waveform` cropped or zero-padded to `self.max_length` samples.

        waveform: [1, samples]. Longer inputs are cropped at a random offset;
        shorter inputs are right-padded with zeros. The result is always
        [1, self.max_length].
        """
        num_samples = waveform.size(1)
        if num_samples > self.max_length:
            # random crop
            start = random.randint(0, num_samples - self.max_length)
            return waveform[:, start:start + self.max_length]

        # Pad on the right; F.pad keeps the dtype and device of the input.
        return torch.nn.functional.pad(waveform, (0, self.max_length - num_samples))

    def _read_audio(self, index):
        """Load the entry at `index`, falling back to random entries on failure.

        Returns a `(text, audio_data, audio_rate)` tuple where `audio_data` is
        the tensor returned by `torchaudio.load(..., channels_first=True)`.

        Raises:
            IndexError: if the dataset has no entries.
            RuntimeError: if no usable entry is found within
                `_MAX_READ_ATTEMPTS` attempts.
        """
        num_items = len(self.all_data_json)
        if num_items == 0:
            raise IndexError('cannot read audio from an empty dataset')

        for _ in range(self._MAX_READ_ATTEMPTS):
            # Bound before the try so the error message below can always
            # reference it, even when the entry lookup itself fails.
            audio_path = None
            try:
                entry = self.all_data_json[index]
                audio_path = entry['wav']
                audio_data, audio_rate = torchaudio.load(audio_path, channels_first=True)
                text = entry['caption']

                # drop short utterance: the sample count is still at the file's
                # native rate here, so the threshold uses audio_rate.
                if audio_data.size(1) < audio_rate * self._MIN_CLIP_SECONDS:
                    raise ValueError(f'{audio_path} is too short, drop it ...')

                return text, audio_data, audio_rate
            except Exception as e:
                # Deliberately best-effort: any broken entry is reported and
                # replaced by a randomly drawn one.
                print(f'error: {e} occurs, when loading {audio_path}')
                index = random.randint(0, num_items - 1)

        raise RuntimeError(
            f'failed to load a usable audio clip after {self._MAX_READ_ATTEMPTS} attempts'
        )

    def __getitem__(self, index):
        """Return a dict with the caption, a [1, max_length] waveform and the modality tag."""
        text, audio_data, audio_rate = self._read_audio(index)

        # convert multi-channel audio to a single channel by averaging all
        # channels; audio_data becomes [samples]
        if audio_data.shape[0] > 1:
            audio_data = audio_data.sum(dim=0) / audio_data.shape[0]
        else:
            audio_data = audio_data.squeeze(0)

        # resample audio clip
        if audio_rate != self.sampling_rate:
            audio_data = torchaudio.functional.resample(
                audio_data, orig_freq=audio_rate, new_freq=self.sampling_rate
            )

        audio_data = self._cut_or_randomcrop(audio_data.unsqueeze(0))

        return {
            'text': text,
            'waveform': audio_data,
            'modality': 'audio_text',
        }