lmzjms's picture
Upload 1162 files
0b32ad6 verified
import re
from pathlib import Path
from torch.utils.data.dataset import Dataset
from torchaudio.sox_effects import apply_effects_file
class QUESST14Dataset(Dataset):
    """QUESST 2014 dataset (English-only).

    Items are laid out as all English queries of the requested split first,
    followed by all English search utterances ("documents") from
    ``language_key_utterances.lst``. ``n_queries`` / ``n_docs`` record the
    size of each segment.
    """

    def __init__(self, split, **kwargs):
        # ``split`` selects ``scoring/language_key_{split}.lst`` for the
        # queries; ``kwargs["dataset_root"]`` is the dataset root directory.
        root = Path(kwargs["dataset_root"])
        documents = english_audio_paths(root, "language_key_utterances.lst")
        queries = english_audio_paths(root, f"language_key_{split}.lst")

        self.dataset_root = root
        self.n_queries = len(queries)
        self.n_docs = len(documents)
        # Queries occupy indices [0, n_queries); documents follow them.
        self.data = [*queries, *documents]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load one file as mono, 16 kHz audio with -3 dB gain applied.

        Returns a ``(waveform, name)`` pair where ``waveform`` is the 1-D
        sample array converted with ``.numpy()`` and ``name`` is the file
        name without its final suffix.
        """
        path = self.data[idx]
        effects = [
            ["channels", "1"],
            ["rate", "16000"],
            ["gain", "-3.0"],
        ]
        waveform, _sample_rate = apply_effects_file(str(path), effects)
        # Down-mixed to one channel, so drop the leading channel axis.
        return waveform.squeeze(0).numpy(), path.with_suffix("").name

    def collate_fn(self, samples):
        """Collate a mini-batch of data.

        Transposes a sequence of ``(waveform, name)`` pairs into a tuple of
        waveforms and a tuple of names; no padding is performed.
        """
        batch_wavs, batch_names = zip(*samples)
        return batch_wavs, batch_names
def english_audio_paths(dataset_root_path, lst_name):
    """Extract English audio paths from a QUESST 2014 scoring list.

    Reads ``<dataset_root_path>/scoring/<lst_name>``. Each non-blank line must
    hold exactly two whitespace-separated fields: an audio path and a
    language tag. Only entries tagged ``nnenglish`` are kept.

    Args:
        dataset_root_path: Dataset root directory (``str`` or ``Path``).
        lst_name: File name of the list inside the ``scoring`` directory.

    Returns:
        List of ``Path`` objects joined onto ``dataset_root_path``, in the
        order they appear in the list file.

    Raises:
        ValueError: If a non-blank line does not contain exactly two fields.
    """
    root = Path(dataset_root_path)
    audio_paths = []
    with open(root / "scoring" / lst_name, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. an empty trailing line at EOF).
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{lst_name}:{line_no}: expected '<audio_path> <language>', "
                    f"got {line.strip()!r}"
                )
            audio_path, lang = fields
            if lang != "nnenglish":
                continue
            # Strip the leading directory component (everything up to and
            # including the first "/"); the remainder is joined onto root.
            audio_path = re.sub(r"^.*?\/", "", audio_path)
            audio_paths.append(root / audio_path)
    return audio_paths