# tokenizer_base.py -- uploaded by murtazadahmardeh (commit 6b913da)
import re
from abc import ABC, abstractmethod
from itertools import groupby
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
class CharsetAdapter:
    """Adjust the letter case of labels to match a target charset.

    Only case folding is applied; no characters are removed from the label.
    """

    def __init__(self, target_charset) -> None:
        """Record the charset and whether it is invariant under lower/upper.

        Args:
            target_charset: String of characters the labels should conform to.
        """
        super().__init__()
        self.charset = target_charset
        lowered = target_charset.lower()
        uppered = target_charset.upper()
        # A charset with no cased characters satisfies both checks.
        self.lowercase_only = lowered == target_charset
        self.uppercase_only = uppered == target_charset

    def __call__(self, label):
        """Return ``label`` case-folded to the charset's case, if it has one.

        The lowercase check takes precedence when both flags are set.
        """
        if self.lowercase_only:
            return label.lower()
        if self.uppercase_only:
            return label.upper()
        return label
class BaseTokenizer(ABC):
    """Abstract base for tokenizers that map label strings to token ids and back.

    The vocabulary is laid out as ``specials_first``, followed by one entry per
    character of ``charset + '[UNK]'``, followed by ``specials_last``.
    Subclasses supply ``encode`` and ``_filter``.
    """

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        """Build the id-to-token tuple and the token-to-id dict.

        Args:
            charset: String whose characters become single-character tokens.
            specials_first: Special tokens placed before the charset tokens.
            specials_last: Special tokens placed after the charset tokens.
        """
        # NOTE(review): tuple(charset + '[UNK]') splits '[UNK]' into the five
        # characters '[', 'U', 'N', 'K', ']' instead of adding one '[UNK]'
        # token. Left as-is because it determines len(self) and the ids of
        # specials_last -- confirm against any saved checkpoints before changing.
        self._itos = specials_first + tuple(charset + '[UNK]') + specials_last
        # If a token occurs more than once in _itos, the last index wins here.
        self._stoi = {s: i for i, s in enumerate(self._itos)}

    def __len__(self):
        """Return the number of entries in the vocabulary."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> List[int]:
        """Map each character of ``tokens`` to its id.

        Raises:
            KeyError: If a character is not in the vocabulary.
        """
        return [self._stoi[s] for s in tokens]

    def _ids2tok(self, token_ids: List[int], join: bool = True) -> Union[str, List[str]]:
        """Map ids back to tokens.

        Args:
            token_ids: Iterable of integer ids indexing the vocabulary.
            join: If True, concatenate the tokens into one string; otherwise
                return them as a list.
        """
        tokens = [self._itos[i] for i in token_ids]
        return ''.join(tokens) if join else tokens

    @abstractmethod
    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[Union[List[str], List[List[str]]], List[Tensor]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection
            if not raw:
                # Subclass-specific post-processing of the greedy path.
                probs, ids = self._filter(probs, ids)
            # Raw output keeps tokens as a list; otherwise they are joined.
            tokens = self._ids2tok(ids, not raw)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs
class Tokenizer(BaseTokenizer):
    """Tokenizer that wraps each label in BOS/EOS markers and pads with PAD."""

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        """Create the vocabulary with EOS first and BOS, PAD last.

        Args:
            charset: String whose characters become single-character tokens.
        """
        super().__init__(charset, (self.EOS,), (self.BOS, self.PAD))
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode labels as ``[BOS] + ids + [EOS]`` rows, padded with ``pad_id``.

        Args:
            labels: List of label strings.
            device: Device on which each row tensor is created.

        Returns:
            Long tensor with the batch dimension first.
        """
        rows = []
        for label in labels:
            wrapped = [self.bos_id] + self._tok2ids(label) + [self.eos_id]
            rows.append(torch.as_tensor(wrapped, dtype=torch.long, device=device))
        return pad_sequence(rows, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Cut the sequence at the first EOS.

        The returned ids exclude EOS, while the returned probs keep one extra
        position so the EOS probability (when present) is included.
        """
        id_list = ids.tolist()
        # Position of the first EOS, or the full length when none was predicted.
        cut = id_list.index(self.eos_id) if self.eos_id in id_list else len(id_list)
        return probs[:cut + 1], id_list[:cut]
class CTCTokenizer(BaseTokenizer):
    """Tokenizer for CTC-style outputs, using a BLANK token at index 0."""

    BLANK = '[B]'

    def __init__(self, charset: str) -> None:
        """Create the vocabulary with BLANK as the only leading special token.

        Args:
            charset: String whose characters become single-character tokens.
        """
        # BLANK uses index == 0 by default
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode labels as id rows padded with ``blank_id``.

        Args:
            labels: List of label strings.
            device: Device on which each row tensor is created.

        Returns:
            Long tensor with the batch dimension first.
        """
        # We use a padded representation since we don't want to use CUDNN's CTC implementation
        batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Best-path decoding: collapse consecutive repeats, then drop BLANKs.

        An empty ``ids`` yields an empty list instead of raising IndexError.
        """
        # Taking only the groupby keys collapses each run of equal ids to one
        # entry and is well-defined for an empty sequence.
        collapsed = [key for key, _ in groupby(ids.tolist())]
        kept = [x for x in collapsed if x != self.blank_id]  # Remove BLANKs
        # `probs` is just pass-through since all positions are considered part of the path
        return probs, kept