|
import json |
|
from pathlib import Path |
|
from string import ascii_lowercase |
|
from typing import List, Union |
|
|
|
import numpy as np |
|
from torch import Tensor |
|
|
|
from hw_asr.base.base_text_encoder import BaseTextEncoder |
|
|
|
|
|
class CharTextEncoder(BaseTextEncoder):
    """Character-level text encoder.

    Maintains two mutually inverse lookup tables:

    * ``ind2char`` -- integer index -> character
    * ``char2ind`` -- character -> integer index

    Indices are assigned by enumerating the *sorted* alphabet, so the mapping
    does not depend on the order in which the alphabet was supplied.
    """

    def __init__(self, alphabet: List[str] = None):
        """Build the lookup tables.

        Args:
            alphabet: characters the encoder should know. When ``None``,
                the lowercase ASCII letters plus the space character are used.
        """
        if alphabet is None:
            alphabet = list(ascii_lowercase + ' ')
        self.alphabet = alphabet
        self.ind2char = dict(enumerate(sorted(alphabet)))
        self.char2ind = {char: ind for ind, char in self.ind2char.items()}

    def __len__(self):
        """Return the number of entries in the index -> character table."""
        return len(self.ind2char)

    def __getitem__(self, item: int):
        """Return the character stored under integer index ``item``."""
        assert type(item) is int
        return self.ind2char[item]

    def encode(self, text) -> Tensor:
        """Convert ``text`` into a tensor of character indices.

        The text is first passed through ``self.normalize_text`` (not defined
        in this class; presumably provided by ``BaseTextEncoder``).

        Returns:
            A tensor with a leading dimension of size 1 followed by one entry
            per character of the normalized text.

        Raises:
            Exception: if the normalized text contains a character that is
                missing from ``char2ind``; the message lists those characters.
        """
        text = self.normalize_text(text)
        try:
            return Tensor([self.char2ind[char] for char in text]).unsqueeze(0)
        except KeyError as e:
            # Sorted so the error message is deterministic across runs.
            unknown_chars = sorted(
                {char for char in text if char not in self.char2ind}
            )
            raise Exception(
                f"Can't encode text '{text}'. Unknown chars: '{' '.join(unknown_chars)}'"
            ) from e

    def decode(self, vector: Union[Tensor, np.ndarray, List[int]]):
        """Map a sequence of indices back to a string.

        Every element is converted with ``int()`` before the lookup, and
        leading/trailing whitespace is stripped from the joined result.
        """
        return ''.join([self.ind2char[int(ind)] for ind in vector]).strip()

    def dump(self, file):
        """Write the index -> character table to ``file`` as JSON."""
        with Path(file).open('w') as f:
            json.dump(self.ind2char, f)

    @classmethod
    def from_file(cls, file):
        """Create an encoder from a JSON file produced by ``dump``.

        The instance is constructed with an empty alphabet and its lookup
        tables are then replaced with the ones restored from ``file``.
        """
        with Path(file).open() as f:
            raw = json.load(f)
        # JSON object keys are always strings, so the integer indices written
        # by ``dump`` come back as "0", "1", ...; convert them back to ints so
        # that ``decode`` and ``__getitem__`` can look them up again.
        ind2char = {int(ind): char for ind, char in raw.items()}
        a = cls([])
        a.ind2char = ind2char
        # Invert over items(): iterating the dict directly yields keys only.
        a.char2ind = {char: ind for ind, char in ind2char.items()}
        return a
|
|