Spaces:
Running
Running
File size: 3,660 Bytes
52f1bcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
from itertools import chain
import random
from typing import List, Optional, Tuple, Union
from tokenizers import AddedToken
from transformers import ByT5Tokenizer
import numpy as np
import torch
from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET
def text_to_utf16_numbers(text):
    """Encode *text* as UTF-16LE and return one int per 2-byte code unit.

    Characters outside the BMP become two surrogate code units, matching
    how the decoder side reassembles them.
    """
    encoded = text.encode('utf-16le')  # little-endian keeps byte order simple
    # Each consecutive byte pair (low, high) forms one 16-bit code unit.
    return [
        encoded[idx] | (encoded[idx + 1] << 8)
        for idx in range(0, len(encoded), 2)
    ]
def utf16_numbers_to_text(numbers):
    """Reassemble UTF-16 code-unit ints (as from text_to_utf16_numbers)
    into a string.

    Invalid sequences (e.g. unpaired surrogates) are silently dropped via
    errors="ignore".
    """
    raw = bytearray()
    for value in numbers:
        # Emit little-endian byte pair: low byte first, then high byte.
        raw.extend(((value & 0xFF), (value >> 8) & 0xFF))
    return raw.decode('utf-16le', errors="ignore")
def _tokenize(text: str, langs: List[str] | None, eos_token_id: int = 1, add_eos: bool = False, add_bos: bool = True):
    """Convert text into shifted UTF-16 token ids, with language tokens prepended.

    Args:
        text: Input string to tokenize.
        langs: Optional language codes; each is mapped through LANGUAGE_MAP
            and placed (shifted past the text-token range) before the text tokens.
        eos_token_id: Id used for both the BOS and EOS marker (shared token).
        add_eos: When True, append ``eos_token_id`` after the text tokens.
        add_bos: When True, insert ``eos_token_id`` at the start of the sequence.

    Returns:
        Tuple of (token id list, language token id list).

    Raises:
        KeyError: If a language code is missing from LANGUAGE_MAP.
    """
    tokens = text_to_utf16_numbers(text)
    # Shift past the special tokens (pad/eos/unk) at the bottom of the id space.
    tokens = [t + TOKEN_OFFSET for t in tokens]

    lang_list = []
    if langs:
        for lang in langs:
            code = LANGUAGE_MAP[lang]
            # Language ids live above the whole text-token range.
            lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS)

    tokens = lang_list + tokens

    # Fix: add_eos was previously accepted but silently ignored.
    if add_eos:
        tokens.append(eos_token_id)
    if add_bos:
        tokens.insert(0, eos_token_id)
    return tokens, lang_list
class Byt5LangTokenizer(ByT5Tokenizer):
    """ByT5-style tokenizer whose vocabulary is raw UTF-16 code units
    (shifted by TOKEN_OFFSET), with per-text language tokens prepended.

    Special token ids are fixed: pad=0, eos=1 (also used as bos), unk=2.
    """

    def __init__(self,
                 eos_token="</s>",
                 unk_token="<unk>",
                 pad_token="<pad>",
                 model_max_length=None,
                 **kwargs,
                 ):
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        # BOS reuses the EOS token string/id in this scheme.
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        # First id above the text range; everything >= this is a language token.
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS

        # NOTE(review): attributes are set before super().__init__() on purpose —
        # keep this order, since the base class may consult them.
        super().__init__()

    def __call__(self, texts: List[str] | str, langs: List[List[str]] | List[str] | None = None, pad_token_id: int = 0, **kwargs):
        """Tokenize one or many texts.

        Args:
            texts: A single string or a list of strings.
            langs: Language codes — None, a flat list (applied to a single
                text), or one list per text.
            pad_token_id: Unused here; accepted for interface compatibility.

        Returns:
            Dict with "input_ids" and "langs"; flat lists when a single
            string was passed, lists of lists otherwise.
        """
        tokenized = []
        all_langs = []

        is_list = True
        # Normalize to list-of-texts / list-of-lang-lists format
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if langs is None:
            langs = [None] * len(texts)

        # Fix: guard the empty case — langs[0] raised IndexError for langs=[].
        if langs and isinstance(langs[0], str):
            langs = [langs]

        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format for a single input string
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
        """Decode token ids back to text.

        NOTE: special tokens (pad/eos/unk) and language tokens are ALWAYS
        stripped, regardless of ``skip_special_tokens`` — the parameter is
        accepted only for interface compatibility.
        """
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        # Keep only ids inside the text-token range, then undo the shift.
        token_ids = [t for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
        token_ids = [t - TOKEN_OFFSET for t in token_ids]
        text = utf16_numbers_to_text(token_ids)
        return text
|