import logging
import random
from typing import List

import numpy as np
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset

from hw_asr.base.base_text_encoder import BaseTextEncoder
from hw_asr.utils.parse_config import ConfigParser

logger = logging.getLogger(__name__)


class BaseDataset(Dataset):
    """Dataset over a pre-built index of audio records.

    Each index entry is a dict with at least:
        - "path": path to the audio file,
        - "text": ground-truth transcription,
        - "audio_len": audio duration in seconds.

    The index is validated, filtered by length limits, optionally
    truncated, and sorted by audio duration (which makes writing
    length-based batch samplers easier later).
    """

    def __init__(
            self,
            index,
            text_encoder: BaseTextEncoder,
            config_parser: ConfigParser,
            wave_augs=None,
            spec_augs=None,
            limit=None,
            max_audio_length=None,
            max_text_length=None,
    ):
        """
        Args:
            index: list of record dicts (see class docstring for fields).
            text_encoder: encoder used to turn transcriptions into tensors.
            config_parser: project config; "preprocessing" section supplies
                the sample rate, log-spec flag and spectrogram transform.
            wave_augs: optional callable applied to the raw waveform.
            spec_augs: optional callable applied to the spectrogram.
            limit: keep at most this many records (after a fixed-seed shuffle).
            max_audio_length: exclude records with audio_len >= this (seconds).
            max_text_length: exclude records whose normalized text length
                is >= this (characters).
        """
        self.text_encoder = text_encoder
        self.config_parser = config_parser
        self.wave_augs = wave_augs
        self.spec_augs = spec_augs
        self.log_spec = config_parser["preprocessing"]["log_spec"]
        # Instantiate the wave->spectrogram transform once here instead of
        # on every process_wave() call: init_obj reads fixed config state,
        # so the transform is loop-invariant across items.
        self.wave2spec = config_parser.init_obj(
            config_parser["preprocessing"]["spectrogram"],
            torchaudio.transforms,
        )

        self._assert_index_is_valid(index)
        index = self._filter_records_from_dataset(
            index, max_audio_length, max_text_length, limit
        )
        # Sort by duration so neighbouring items have similar lengths.
        index = self._sort_index(index)
        self._index: List[dict] = index

    def __getitem__(self, ind):
        """Load, augment and featurize the ind-th record.

        Returns a dict with the waveform, spectrogram, duration (seconds),
        raw text, encoded text and the source audio path.
        """
        data_dict = self._index[ind]
        audio_path = data_dict["path"]
        audio_wave = self.load_audio(audio_path)
        audio_wave, audio_spec = self.process_wave(audio_wave)
        return {
            "audio": audio_wave,
            "spectrogram": audio_spec,
            "duration": audio_wave.size(1) / self.config_parser["preprocessing"]["sr"],
            "text": data_dict["text"],
            "text_encoded": self.text_encoder.encode(data_dict["text"]),
            "audio_path": audio_path,
        }

    @staticmethod
    def _sort_index(index):
        """Return the index sorted by ascending audio duration."""
        return sorted(index, key=lambda x: x["audio_len"])

    def __len__(self):
        return len(self._index)

    def load_audio(self, path):
        """Load an audio file as a mono tensor resampled to the target rate."""
        audio_tensor, sr = torchaudio.load(path)
        audio_tensor = audio_tensor[0:1, :]  # remove all channels but the first
        target_sr = self.config_parser["preprocessing"]["sr"]
        if sr != target_sr:
            audio_tensor = torchaudio.functional.resample(audio_tensor, sr, target_sr)
        return audio_tensor

    def process_wave(self, audio_tensor_wave: Tensor):
        """Apply wave augs, compute the spectrogram, apply spec augs.

        Returns:
            (augmented waveform, (log-)spectrogram) tuple.
        """
        with torch.no_grad():
            if self.wave_augs is not None:
                audio_tensor_wave = self.wave_augs(audio_tensor_wave)
            audio_tensor_spec = self.wave2spec(audio_tensor_wave)
            if self.spec_augs is not None:
                audio_tensor_spec = self.spec_augs(audio_tensor_spec)
            if self.log_spec:
                # Small epsilon keeps log() finite on zero-energy bins.
                audio_tensor_spec = torch.log(audio_tensor_spec + 1e-5)
            return audio_tensor_wave, audio_tensor_spec

    @staticmethod
    def _filter_records_from_dataset(
            index: list, max_audio_length, max_text_length, limit
    ) -> list:
        """Drop over-long records, then optionally shuffle-and-truncate.

        Records with audio_len >= max_audio_length or normalized text
        length >= max_text_length are excluded (each limit is skipped when
        None). If limit is not None the survivors are shuffled with a fixed
        seed and the first `limit` are kept.
        """
        initial_size = len(index)
        if max_audio_length is not None:
            exceeds_audio_length = (
                np.array([el["audio_len"] for el in index]) >= max_audio_length
            )
            _total = exceeds_audio_length.sum()
            logger.info(
                f"{_total} ({_total / initial_size:.1%}) records are longer than "
                f"{max_audio_length} seconds. Excluding them."
            )
        else:
            exceeds_audio_length = False

        if max_text_length is not None:
            exceeds_text_length = (
                np.array(
                    [len(BaseTextEncoder.normalize_text(el["text"])) for el in index]
                )
                >= max_text_length
            )
            _total = exceeds_text_length.sum()
            logger.info(
                f"{_total} ({_total / initial_size:.1%}) records are longer than "
                f"{max_text_length} characters. Excluding them."
            )
        else:
            exceeds_text_length = False

        # Either operand may be plain `False` (limit disabled); `False | arr`
        # broadcasts correctly, and `False | False` stays a bool.
        records_to_filter = exceeds_text_length | exceeds_audio_length
        if records_to_filter is not False and records_to_filter.any():
            _total = records_to_filter.sum()
            index = [el for el, exclude in zip(index, records_to_filter) if not exclude]
            logger.info(
                f"Filtered {_total}({_total / initial_size:.1%}) records from dataset"
            )

        if limit is not None:
            # Dedicated RNG: Random(42) reproduces the exact shuffle that
            # random.seed(42); random.shuffle(...) gave, without clobbering
            # the global random state for the rest of the program.
            random.Random(42).shuffle(index)  # best seed for deep learning
            index = index[:limit]
        return index

    @staticmethod
    def _assert_index_is_valid(index):
        """Assert every index entry carries the required fields."""
        for entry in index:
            assert "audio_len" in entry, (
                "Each dataset item should include field 'audio_len'"
                " - duration of audio (in seconds)."
            )
            assert "path" in entry, (
                "Each dataset item should include field 'path'"
                " - path to audio file."
            )
            assert "text" in entry, (
                "Each dataset item should include field 'text'"
                " - text transcription of the audio."
            )