from itertools import chain
import random
from typing import List, Optional, Tuple, Union

from tokenizers import AddedToken
from transformers import ByT5Tokenizer
import numpy as np
import torch

from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET


def text_to_utf16_numbers(text):
    """Encode ``text`` as a list of UTF-16 code units (ints in ``0..0xFFFF``).

    Characters outside the Basic Multilingual Plane become two (surrogate)
    code units, as UTF-16 requires.
    """
    # Little-endian so the low byte of every unit comes first.
    utf16_bytes = text.encode('utf-16le')
    # Combine each adjacent (low, high) byte pair into a single 16-bit number.
    return [
        utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
        for i in range(0, len(utf16_bytes), 2)
    ]


def utf16_numbers_to_text(numbers):
    """Decode a sequence of UTF-16 code units back into a string.

    Only the low 16 bits of each number are used. Byte sequences that are not
    valid UTF-16 (e.g. unpaired surrogates) are dropped silently because
    decoding uses ``errors="ignore"``.
    """
    byte_array = bytearray()
    for number in numbers:
        byte_array.append(number & 0xFF)  # Lower byte
        byte_array.append((number >> 8) & 0xFF)  # Upper byte
    return byte_array.decode('utf-16le', errors="ignore")


def _tokenize(text: str, langs: List[str] | None, eos_token_id: int = 1, add_eos: bool = False, add_bos: bool = True):
    """Convert one text (plus optional language codes) into token ids.

    Resulting layout: ``[eos_token_id if add_bos] + language ids + text ids
    + [eos_token_id if add_eos]``.

    Args:
        text: The string to encode.
        langs: Language codes looked up in ``LANGUAGE_MAP``; ``None`` or an
            empty list means no language ids.
        eos_token_id: Id used both as the leading BOS marker and trailing EOS.
        add_eos: Append ``eos_token_id`` after the text ids.
        add_bos: Prepend ``eos_token_id`` before the language ids.

    Returns:
        ``(tokens, lang_list)`` where ``lang_list`` holds only the language ids.

    Raises:
        KeyError: If a language code is not in ``LANGUAGE_MAP``.
    """
    # Shift every UTF-16 unit past the reserved low ids (pad/eos/unk, ...).
    tokens = [t + TOKEN_OFFSET for t in text_to_utf16_numbers(text)]

    # Language ids start at TOKEN_OFFSET + TOTAL_TOKENS, the same bound the
    # tokenizer's decode() uses to exclude non-text ids.
    lang_list = [LANGUAGE_MAP[lang] + TOKEN_OFFSET + TOTAL_TOKENS for lang in (langs or [])]
    tokens = lang_list + tokens

    if add_bos:
        # The eos id doubles as the beginning-of-sequence marker.
        tokens.insert(0, eos_token_id)
    if add_eos:
        tokens.append(eos_token_id)

    return tokens, lang_list


class Byt5LangTokenizer(ByT5Tokenizer):
    """ByT5-derived tokenizer that encodes text as UTF-16 code units plus language ids.

    Id layout used by this class:
      * ``0`` / ``1`` / ``2``: pad / eos / unk (``pad_id``, ``eos_id``, ``unk_id``);
        the eos id is also used as BOS.
      * ``[TOKEN_OFFSET, TOKEN_OFFSET + TOTAL_TOKENS)``: text ids
        (UTF-16 code unit + ``TOKEN_OFFSET``); ``decode`` keeps only these.
      * ``>= TOKEN_OFFSET + TOTAL_TOKENS``: language ids
        (``LANGUAGE_MAP`` code + ``TOKEN_OFFSET + TOTAL_TOKENS``).
    """

    def __init__(self,
                 eos_token="",
                 unk_token="",
                 pad_token="",
                 model_max_length=None,
                 **kwargs,
                 ):
        # NOTE(review): extra **kwargs are accepted but not forwarded to the
        # base class; this is unchanged from the original.
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        # First id above the text range; decode() drops ids at or above this.
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS

        # The base tokenizer __init__ assigns self.model_max_length from its
        # own kwargs (a very large default when absent), which would overwrite
        # the assignment above; forward the value so a caller-supplied limit
        # survives. With None this is identical to calling super().__init__().
        super().__init__(model_max_length=model_max_length)

    def __call__(self, texts: List[str] | str, langs: List[List[str]] | List[str] | None = None, pad_token_id: int = 0, **kwargs):
        """Tokenize a single text or a batch of texts.

        Args:
            texts: One string, or a list of strings.
            langs: For a single text, a flat list of language codes; for a
                batch, one list (or ``None``) per text. ``None`` means no
                language ids for any text.
            pad_token_id: Accepted for API compatibility; no padding is done.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            ``{"input_ids": ..., "langs": ...}``: lists of ids when ``texts``
            was a list, otherwise the single text's ids.
        """
        is_list = not isinstance(texts, str)
        if not is_list:
            texts = [texts]

        if langs is None:
            langs = [None] * len(texts)

        # A flat list of codes belongs to one text, so wrap it. Check for
        # emptiness first: langs[0] would raise on an empty batch, and an
        # empty flat list for a single text just means "no languages".
        is_flat = bool(langs) and isinstance(langs[0], str)
        if is_flat or (not is_list and not langs):
            langs = [langs]

        assert len(langs) == len(texts), "langs must provide one entry per text"

        tokenized = []
        all_langs = []
        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang, eos_token_id=self.eos_id)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format for a single input text.
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
            self,
            token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: bool = None,
            **kwargs,
    ) -> str:
        """Convert one sequence of ids back into text.

        Every id outside ``[TOKEN_OFFSET, special_token_start)`` (pad, eos,
        unk, language ids) is always dropped. ``skip_special_tokens`` and
        ``clean_up_tokenization_spaces`` are accepted for API compatibility
        and ignored.
        """
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()
        # A bare id, or a 0-d array/tensor (whose tolist() is a scalar), is
        # not iterable; treat it as a one-element sequence.
        if isinstance(token_ids, (int, np.integer)):
            token_ids = [int(token_ids)]

        text_units = [
            t - TOKEN_OFFSET
            for t in token_ids
            if TOKEN_OFFSET <= t < self.special_token_start
        ]
        return utf16_numbers_to_text(text_units)