import json from pathlib import Path from string import ascii_lowercase from typing import List, Union import numpy as np from torch import Tensor from hw_asr.base.base_text_encoder import BaseTextEncoder class CharTextEncoder(BaseTextEncoder): def __init__(self, alphabet: List[str] = None): if alphabet is None: alphabet = list(ascii_lowercase + ' ') self.alphabet = alphabet self.ind2char = {k: v for k, v in enumerate(sorted(alphabet))} self.char2ind = {v: k for k, v in self.ind2char.items()} def __len__(self): return len(self.ind2char) def __getitem__(self, item: int): assert type(item) is int return self.ind2char[item] def encode(self, text) -> Tensor: text = self.normalize_text(text) try: return Tensor([self.char2ind[char] for char in text]).unsqueeze(0) except KeyError as e: unknown_chars = set([char for char in text if char not in self.char2ind]) raise Exception( f"Can't encode text '{text}'. Unknown chars: '{' '.join(unknown_chars)}'") def decode(self, vector: Union[Tensor, np.ndarray, List[int]]): return ''.join([self.ind2char[int(ind)] for ind in vector]).strip() def dump(self, file): with Path(file).open('w') as f: json.dump(self.ind2char, f) @classmethod def from_file(cls, file): with Path(file).open() as f: ind2char = json.load(f) a = cls([]) a.ind2char = ind2char a.char2ind = {v: k for k, v in ind2char} return a