# Uploaded by lmzjms ("Upload 1162 files", commit 0b32ad6, verified).
import re
import xml.etree.ElementTree as ET
from pathlib import Path
from torch.utils.data.dataset import Dataset
from torchaudio.sox_effects import apply_effects_file
class SWS2013Testset(Dataset):
    """SWS 2013 test set: query recordings followed by document recordings.

    Items are indexed with every query first (``n_queries`` of them) and then
    every document (``n_docs`` of them). Each item is loaded through sox
    effects and cut into fixed-length, overlapping segments.
    """

    def __init__(self, split, **kwargs):
        """Read the query and document name lists for ``split``.

        Args:
            split: Either ``"dev"`` or ``"eval"``.
            **kwargs: Must contain ``sws2013_scoring_root`` (directory holding
                the ``sws2013_{split}`` scoring XML files) and ``sws2013_root``
                (directory holding the audio). Other keys are ignored.
        """
        assert split in ("dev", "eval")

        split_dir = Path(kwargs["sws2013_scoring_root"]) / f"sws2013_{split}"
        doc_names = parse_ecf(split_dir / "sws2013.ecf.xml")
        term_names = parse_tlist(split_dir / f"sws2013_{split}.tlist.xml")

        self.dataset_root = Path(kwargs["sws2013_root"])
        self.split = split
        self.n_queries = len(term_names)
        self.n_docs = len(doc_names)
        # Queries are stored first so the index alone tells a query from a
        # document (see __getitem__).
        self.data = term_names + doc_names

    def __len__(self):
        """Return the total number of queries plus documents."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load item ``idx`` and split it into overlapping segments.

        Returns:
            A tuple ``(segments, n_segments, audio_name)`` where ``segments``
            is a tuple of 1-D tensors of 48000 samples each.
        """
        name = self.data[idx]

        # Queries and documents live in different sub-directories.
        if idx < self.n_queries:
            folder = self.dataset_root / f"{self.split}_queries"
        else:
            folder = self.dataset_root / "Audio"
        wav_file = (folder / name).with_suffix(".wav")

        effects = [
            # Mono, 16 kHz, peak-normalised.
            ["channels", "1"],
            ["rate", "16000"],
            ["norm"],
            # vad trims only the front, so trim, flip, trim again and flip
            # back to strip silence from both ends.
            ["vad", "-T", "0.25", "-p", "0.1"],
            ["reverse"],
            ["vad", "-T", "0.25", "-p", "0.1"],
            ["reverse"],
            # Append 3 s of silence so at least one full window always fits.
            ["pad", "0", "3"],
        ]
        waveform, _ = apply_effects_file(str(wav_file), effects)

        # 48000-sample windows (3 s at 16 kHz) taken every 12000 samples.
        windows = waveform.squeeze(0).unfold(0, 48000, 12000).unbind(0)
        return windows, len(windows), name

    def collate_fn(self, samples):
        """Collate a mini-batch of data.

        Flattens the per-item segment tuples into one list and returns it
        together with ``(lengths, audio_names)`` so callers can regroup the
        segments by item afterwards.
        """
        per_item_segments, lengths, audio_names = zip(*samples)

        flat_segments = []
        for item_segments in per_item_segments:
            flat_segments.extend(item_segments)

        return flat_segments, (lengths, audio_names)
def parse_ecf(ecf_path):
    """Return the audio names listed in an ECF file such as sws2013.ecf.xml.

    Each ``excerpt`` element directly under the root contributes its
    ``audio_filename`` attribute with every ``"Audio/"`` and ``".wav"``
    occurrence removed. Names are returned in document order.
    """
    tree_root = ET.parse(str(ecf_path)).getroot()
    return [
        node.attrib["audio_filename"].replace("Audio/", "").replace(".wav", "")
        for node in tree_root.findall("excerpt")
    ]
def parse_tlist(tlist_path):
    """Return the term ids listed in a term-list file (e.g. sws2013_eval.tlist.xml).

    Each ``term`` element directly under the root contributes its ``termid``
    attribute. Ids are returned in document order.
    """
    tree_root = ET.parse(str(tlist_path)).getroot()
    return [node.attrib["termid"] for node in tree_root.findall("term")]