tgritsaev's picture
Upload 198 files
affcd23 verified
raw
history blame
1.61 kB
import json
from pathlib import Path
from string import ascii_lowercase
from typing import List, Union
import numpy as np
from torch import Tensor
from hw_asr.base.base_text_encoder import BaseTextEncoder
class CharTextEncoder(BaseTextEncoder):
    """Character-level text encoder mapping characters to integer indices.

    Attributes are set in ``__init__``:
        alphabet: the list of characters passed in (or the default one).
        ind2char: ``{index: char}`` built by enumerating the sorted alphabet.
        char2ind: the inverse ``{char: index}`` mapping.
    """

    def __init__(self, alphabet: Union[List[str], None] = None):
        """Build the index <-> character lookup tables.

        Args:
            alphabet: characters to support. When ``None``, lowercase ASCII
                letters plus the space character are used.
        """
        if alphabet is None:
            alphabet = list(ascii_lowercase + ' ')
        self.alphabet = alphabet
        # Indices are assigned in sorted character order, not input order.
        self.ind2char = dict(enumerate(sorted(alphabet)))
        self.char2ind = {char: ind for ind, char in self.ind2char.items()}

    def __len__(self):
        """Return the number of entries in the index -> character table."""
        return len(self.ind2char)

    def __getitem__(self, item: int):
        """Return the character stored at index ``item``.

        Only exact ``int`` instances are accepted (not ``bool`` or other
        subclasses); the check is an ``assert`` and is therefore skipped
        when Python runs with ``-O``.
        """
        assert type(item) is int
        return self.ind2char[item]

    def encode(self, text) -> Tensor:
        """Convert ``text`` to a tensor of character indices.

        The text is first passed through ``self.normalize_text``, which is
        not defined in this class (presumably inherited from
        ``BaseTextEncoder`` -- confirm there).

        Returns:
            A tensor of indices with a leading dimension of size 1 added
            by ``unsqueeze(0)``.

        Raises:
            Exception: if the normalized text contains a character that is
                missing from ``char2ind``; the original ``KeyError`` is
                attached as the cause.
        """
        text = self.normalize_text(text)
        try:
            return Tensor([self.char2ind[char] for char in text]).unsqueeze(0)
        except KeyError as e:
            # Sorted so the error message is deterministic across runs.
            unknown_chars = sorted({char for char in text if char not in self.char2ind})
            raise Exception(
                f"Can't encode text '{text}'. Unknown chars: '{' '.join(unknown_chars)}'") from e

    def decode(self, vector: Union[Tensor, np.ndarray, List[int]]):
        """Convert a sequence of indices back to a string.

        Each element is cast with ``int`` before lookup; leading and
        trailing whitespace is stripped from the joined result.
        """
        return ''.join(self.ind2char[int(ind)] for ind in vector).strip()

    def dump(self, file):
        """Write the index -> character table to ``file`` as JSON.

        JSON object keys are always strings, so the integer indices are
        stored as strings; ``from_file`` converts them back.
        """
        with Path(file).open('w') as f:
            json.dump(self.ind2char, f)

    @classmethod
    def from_file(cls, file):
        """Create an encoder from a JSON file written by ``dump``.

        Args:
            file: path to a JSON object mapping index -> character.

        Returns:
            An instance of ``cls`` whose ``ind2char`` / ``char2ind`` tables
            come from the file.

        Raises:
            ValueError: if a key in the file cannot be converted to ``int``.
        """
        with Path(file).open() as f:
            loaded = json.load(f)
        # Start from an empty alphabet; both tables are replaced below.
        encoder = cls([])
        # json.load yields string keys ("0", "1", ...); restore int indices
        # so that __getitem__ and decode, which look up ints, keep working.
        encoder.ind2char = {int(ind): char for ind, char in loaded.items()}
        encoder.char2ind = {char: ind for ind, char in encoder.ind2char.items()}
        return encoder