tgritsaev's picture
Upload 198 files
affcd23 verified
import json
import logging
import os
import shutil
from pathlib import Path
import torchaudio
from speechbrain.utils.data_utils import download_file
from tqdm import tqdm
from hw_asr.base.base_dataset import BaseDataset
from hw_asr.utils import ROOT_PATH
logger = logging.getLogger(__name__)

# OpenSLR (resource 12) download URL for every LibriSpeech split, keyed by the
# split name. Insertion order is significant: the "train_all" mode of
# LibrispeechDataset walks these keys in order and concatenates the train splits.
URL_LINKS = dict(
    [
        ("dev-clean", "https://www.openslr.org/resources/12/dev-clean.tar.gz"),
        ("dev-other", "https://www.openslr.org/resources/12/dev-other.tar.gz"),
        ("test-clean", "https://www.openslr.org/resources/12/test-clean.tar.gz"),
        ("test-other", "https://www.openslr.org/resources/12/test-other.tar.gz"),
        ("train-clean-100", "https://www.openslr.org/resources/12/train-clean-100.tar.gz"),
        ("train-clean-360", "https://www.openslr.org/resources/12/train-clean-360.tar.gz"),
        ("train-other-500", "https://www.openslr.org/resources/12/train-other-500.tar.gz"),
    ]
)
class LibrispeechDataset(BaseDataset):
    """LibriSpeech ASR dataset backed by a cached JSON index.

    The first time a split is requested, its archive is downloaded from the
    URL in ``URL_LINKS``, unpacked into ``data_dir`` and scanned. The
    resulting index is a list of dicts with keys ``"path"``, ``"text"`` and
    ``"audio_len"``. It is cached as ``<part>_index.json`` so later runs
    skip the download and the scan.
    """

    def __init__(self, part, data_dir=None, *args, **kwargs):
        """
        Args:
            part: a key of ``URL_LINKS`` (e.g. ``"train-clean-100"``), or
                ``"train_all"`` to concatenate every split whose name
                contains ``"train"`` (in ``URL_LINKS`` order).
            data_dir: directory holding archives, split folders and index
                files. Defaults to ``ROOT_PATH / "data" / "datasets" /
                "librispeech"``. A ``str`` or ``Path`` is accepted, and the
                directory is created if missing.
            *args, **kwargs: forwarded unchanged to ``BaseDataset.__init__``.

        Raises:
            ValueError: if ``part`` is not a known split name.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if part not in URL_LINKS and part != "train_all":
            raise ValueError(
                f"Unknown LibriSpeech part {part!r}; expected one of "
                f"{list(URL_LINKS)} or 'train_all'"
            )
        if data_dir is None:
            data_dir = ROOT_PATH / "data" / "datasets" / "librispeech"
        # Normalise to Path: the rest of the class builds paths with ``/``.
        data_dir = Path(data_dir)
        data_dir.mkdir(exist_ok=True, parents=True)
        self._data_dir = data_dir

        if part == "train_all":
            # extend() is linear, whereas sum(list_of_lists, []) re-copies
            # the accumulated list on every addition.
            index = []
            for train_part in URL_LINKS:
                if "train" in train_part:
                    index.extend(self._get_or_load_index(train_part))
        else:
            index = self._get_or_load_index(part)

        super().__init__(index, *args, **kwargs)

    def _load_part(self, part):
        """Download and unpack the ``part`` archive into ``self._data_dir``.

        The archive's top-level ``LibriSpeech/`` folder is flattened, so its
        entries (including the ``part`` split folder) are moved directly
        under ``self._data_dir``. Afterwards the downloaded archive and the
        emptied ``LibriSpeech/`` folder are deleted.
        """
        arch_path = self._data_dir / f"{part}.tar.gz"
        print(f"Loading part {part}")
        download_file(URL_LINKS[part], arch_path)
        shutil.unpack_archive(arch_path, self._data_dir)
        unpacked_root = self._data_dir / "LibriSpeech"
        for fpath in unpacked_root.iterdir():
            shutil.move(str(fpath), str(self._data_dir / fpath.name))
        os.remove(str(arch_path))
        shutil.rmtree(str(unpacked_root))

    def _get_or_load_index(self, part):
        """Return the index for ``part``.

        If ``<part>_index.json`` exists in the data dir it is read. Otherwise
        the index is built with ``_create_index`` and written there first.
        """
        index_path = self._data_dir / f"{part}_index.json"
        if index_path.exists():
            with index_path.open(encoding="utf-8") as f:
                return json.load(f)
        index = self._create_index(part)
        with index_path.open("w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)
        return index

    def _create_index(self, part):
        """Scan the ``part`` split folder, downloading it first if absent.

        Every directory containing ``.flac`` files is expected to also hold
        a ``*.trans.txt`` transcript. Each transcript line has the form
        ``<utterance_id> <word> <word> ...``.

        Returns:
            A list of dicts, each with:

            - ``"path"``: absolute path of the ``.flac`` file.
            - ``"text"``: the lower-cased transcript.
            - ``"audio_len"``: duration in seconds, computed as
              ``num_frames / sample_rate`` from ``torchaudio.info``.

        Raises:
            FileNotFoundError: if a folder with ``.flac`` files has no
                ``*.trans.txt`` transcript.
        """
        index = []
        split_dir = self._data_dir / part
        if not split_dir.exists():
            self._load_part(part)

        flac_dirs = set()
        for dirpath, _dirnames, filenames in os.walk(str(split_dir)):
            if any(f.endswith(".flac") for f in filenames):
                flac_dirs.add(dirpath)

        # Sort so the index order does not depend on string-hash (set)
        # ordering, which changes between interpreter runs.
        for flac_dir in tqdm(
            sorted(flac_dirs), desc=f"Preparing librispeech folders: {part}"
        ):
            flac_dir = Path(flac_dir)
            trans_path = next(flac_dir.glob("*.trans.txt"), None)
            if trans_path is None:
                raise FileNotFoundError(
                    f"No *.trans.txt transcript found in {flac_dir}"
                )
            with trans_path.open(encoding="utf-8") as f:
                for line in f:
                    tokens = line.split()
                    if not tokens:  # tolerate blank lines instead of IndexError
                        continue
                    f_id, *words = tokens
                    flac_path = flac_dir / f"{f_id}.flac"
                    t_info = torchaudio.info(str(flac_path))
                    index.append(
                        {
                            "path": str(flac_path.absolute().resolve()),
                            "text": " ".join(words).lower(),
                            "audio_len": t_info.num_frames / t_info.sample_rate,
                        }
                    )
        return index