Spaces:
Running
Running
from dataclasses import dataclass | |
import torch | |
from transformers import PretrainedConfig | |
from transformers.utils import ModelOutput | |
class SuryaOCRConfig(PretrainedConfig): | |
model_type = "vision-encoder-decoder" | |
is_composition = True | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
encoder_config = kwargs.pop("encoder") | |
decoder_config = kwargs.pop("decoder") | |
self.encoder = encoder_config | |
self.decoder = decoder_config | |
self.is_encoder_decoder = True | |
if isinstance(decoder_config, dict): | |
self.decoder_start_token_id = decoder_config["bos_token_id"] | |
self.pad_token_id = decoder_config["pad_token_id"] | |
self.eos_token_id = decoder_config["eos_token_id"] | |
else: | |
self.decoder_start_token_id = decoder_config.bos_token_id | |
self.pad_token_id = decoder_config.pad_token_id | |
self.eos_token_id = decoder_config.eos_token_id | |
class DonutSwinConfig(PretrainedConfig): | |
model_type = "donut-swin" | |
attribute_map = { | |
"num_attention_heads": "num_heads", | |
"num_hidden_layers": "num_layers", | |
} | |
def __init__( | |
self, | |
image_size=(256, 896), | |
patch_size=4, | |
num_channels=3, | |
embed_dim=128, | |
depths=[2, 2, 14, 2], | |
num_heads=[4, 8, 16, 32], | |
num_kv_heads=[1, 2, 4, 8], | |
window_size=7, | |
mlp_ratio=4.0, | |
qkv_bias=True, | |
hidden_dropout_prob=0.0, | |
attention_probs_dropout_prob=0.0, | |
drop_path_rate=0.1, | |
hidden_act="gelu", | |
use_absolute_embeddings=True, | |
initializer_range=0.02, | |
layer_norm_eps=1e-5, | |
encoder_length=256, | |
**kwargs, | |
): | |
super().__init__(**kwargs) | |
self.image_size = image_size | |
self.patch_size = patch_size | |
self.num_channels = num_channels | |
self.embed_dim = embed_dim | |
self.depths = depths | |
self.num_layers = len(depths) | |
self.num_heads = num_heads | |
self.num_kv_heads = num_kv_heads | |
self.window_size = window_size | |
self.mlp_ratio = mlp_ratio | |
self.qkv_bias = qkv_bias | |
self.hidden_dropout_prob = hidden_dropout_prob | |
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |
self.drop_path_rate = drop_path_rate | |
self.hidden_act = hidden_act | |
self.use_absolute_embeddings = use_absolute_embeddings | |
self.layer_norm_eps = layer_norm_eps | |
self.initializer_range = initializer_range | |
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel | |
# this indicates the channel dimension after the last stage of the model | |
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) | |
self.encoder_length = encoder_length | |
class SuryaOCRDecoderConfig(PretrainedConfig): | |
model_type = "surya_ocr" | |
def __init__( | |
self, | |
num_hidden_layers=10, | |
vocab_size=65792, | |
hidden_size=1024, | |
intermediate_size=4 * 1024, | |
num_attention_heads=16, | |
lru_width=None, | |
attention_window_size=16, | |
conv1d_width=4, | |
logits_soft_cap=30.0, | |
rms_norm_eps=1e-6, | |
use_cache=True, | |
pad_token_id=0, | |
eos_token_id=1, | |
bos_token_id=1, | |
hidden_activation="gelu_pytorch_tanh", | |
rope_theta=10000.0, | |
block_types=("attention",), | |
cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), | |
self_attn_layers=(0, 1, 3, 5, 7, 9), | |
global_attn_layers=(0, 1, 3, 5, 7, 9), | |
attention_dropout=0.0, | |
num_key_value_heads=2, | |
attention_bias=False, | |
w_init_variance_scale=0.01, | |
init_std=0.02, | |
tie_word_embeddings=False, | |
aux_heads=0, # How many n-token-ahead heads to add | |
encoder_hidden_size=1024, | |
causal=False, | |
**kwargs, | |
): | |
self.num_hidden_layers = num_hidden_layers | |
self.vocab_size = vocab_size | |
self.hidden_size = hidden_size | |
self.intermediate_size = intermediate_size | |
self.num_attention_heads = num_attention_heads | |
self.lru_width = lru_width if lru_width is not None else hidden_size | |
self.attention_window_size = attention_window_size | |
self.conv1d_width = conv1d_width | |
self.logits_soft_cap = logits_soft_cap | |
self.rms_norm_eps = rms_norm_eps | |
self.use_cache = use_cache | |
self.rope_theta = rope_theta | |
self.block_types = list(block_types) | |
self.hidden_activation = hidden_activation | |
self.head_dim = self.hidden_size // self.num_attention_heads | |
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads | |
if self.num_key_value_heads > self.num_attention_heads: | |
raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") | |
self.cross_attn_layers = cross_attn_layers | |
self.self_attn_layers = self_attn_layers | |
self.global_attn_layers = global_attn_layers | |
self.attention_dropout = attention_dropout | |
self.attention_bias = attention_bias | |
self.w_init_variance_scale = w_init_variance_scale | |
self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers | |
self.init_std = init_std | |
self.tie_word_embeddings = tie_word_embeddings | |
self.aux_heads = aux_heads | |
self.encoder_hidden_size = encoder_hidden_size | |
self.causal = causal | |
super().__init__( | |
pad_token_id=pad_token_id, | |
bos_token_id=bos_token_id, | |
eos_token_id=eos_token_id, | |
**kwargs, | |
) | |
def layers_block_type(self): | |
return (self.block_types * 100)[: self.num_hidden_layers] | |
class SuryaOCRTextEncoderConfig(PretrainedConfig): | |
model_type = "surya_ocr" | |
def __init__( | |
self, | |
num_hidden_layers=10, | |
vocab_size=65792, | |
hidden_size=1024, | |
intermediate_size=4 * 1024, | |
num_attention_heads=16, | |
lru_width=None, | |
attention_window_size=16, | |
conv1d_width=4, | |
logits_soft_cap=30.0, | |
rms_norm_eps=1e-6, | |
use_cache=True, | |
pad_token_id=0, | |
eos_token_id=1, | |
bos_token_id=1, | |
hidden_activation="gelu_pytorch_tanh", | |
rope_theta=10000.0, | |
block_types=("attention",), | |
cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), | |
self_attn_layers=(0, 1, 3, 5, 7, 9), | |
global_attn_layers=(0, 1, 3, 5, 7, 9), | |
attention_dropout=0.0, | |
num_key_value_heads=2, | |
attention_bias=False, | |
w_init_variance_scale=0.01, | |
init_std=0.02, | |
tie_word_embeddings=False, | |
aux_heads=0, # How many n-token-ahead heads to add | |
encoder_hidden_size=1024, | |
iteration_count=1, | |
causal=False, | |
query_token_count=128, | |
**kwargs, | |
): | |
self.num_hidden_layers = num_hidden_layers | |
self.vocab_size = vocab_size | |
self.hidden_size = hidden_size | |
self.intermediate_size = intermediate_size | |
self.num_attention_heads = num_attention_heads | |
self.lru_width = lru_width if lru_width is not None else hidden_size | |
self.attention_window_size = attention_window_size | |
self.conv1d_width = conv1d_width | |
self.logits_soft_cap = logits_soft_cap | |
self.rms_norm_eps = rms_norm_eps | |
self.use_cache = use_cache | |
self.rope_theta = rope_theta | |
self.block_types = list(block_types) | |
self.hidden_activation = hidden_activation | |
self.head_dim = self.hidden_size // self.num_attention_heads | |
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads | |
if self.num_key_value_heads > self.num_attention_heads: | |
raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") | |
self.cross_attn_layers = cross_attn_layers | |
self.self_attn_layers = self_attn_layers | |
self.global_attn_layers = global_attn_layers | |
self.attention_dropout = attention_dropout | |
self.attention_bias = attention_bias | |
self.w_init_variance_scale = w_init_variance_scale | |
self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers | |
self.init_std = init_std | |
self.tie_word_embeddings = tie_word_embeddings | |
self.aux_heads = aux_heads | |
self.encoder_hidden_size = encoder_hidden_size | |
self.iteration_count = iteration_count | |
self.causal = causal | |
self.query_token_count = query_token_count | |
super().__init__( | |
pad_token_id=pad_token_id, | |
bos_token_id=bos_token_id, | |
eos_token_id=eos_token_id, | |
**kwargs, | |
) | |
def layers_block_type(self): | |
return (self.block_types * 100)[: self.num_hidden_layers] | |
TOTAL_TOKENS = 65536 | |
TOKEN_OFFSET = 3 # Pad, eos, bos | |
SPECIAL_TOKENS = 253 | |
TOTAL_VOCAB_SIZE = TOTAL_TOKENS + TOKEN_OFFSET + SPECIAL_TOKENS | |
LANGUAGE_MAP = { | |
'af': 0, | |
'am': 1, | |
'ar': 2, | |
'as': 3, | |
'az': 4, | |
'be': 5, | |
'bg': 6, | |
'bn': 7, | |
'br': 8, | |
'bs': 9, | |
'ca': 10, | |
'cs': 11, | |
'cy': 12, | |
'da': 13, | |
'de': 14, | |
'el': 15, | |
'en': 16, | |
'eo': 17, | |
'es': 18, | |
'et': 19, | |
'eu': 20, | |
'fa': 21, | |
'fi': 22, | |
'fr': 23, | |
'fy': 24, | |
'ga': 25, | |
'gd': 26, | |
'gl': 27, | |
'gu': 28, | |
'ha': 29, | |
'he': 30, | |
'hi': 31, | |
'hr': 32, | |
'hu': 33, | |
'hy': 34, | |
'id': 35, | |
'is': 36, | |
'it': 37, | |
'ja': 38, | |
'jv': 39, | |
'ka': 40, | |
'kk': 41, | |
'km': 42, | |
'kn': 43, | |
'ko': 44, | |
'ku': 45, | |
'ky': 46, | |
'la': 47, | |
'lo': 48, | |
'lt': 49, | |
'lv': 50, | |
'mg': 51, | |
'mk': 52, | |
'ml': 53, | |
'mn': 54, | |
'mr': 55, | |
'ms': 56, | |
'my': 57, | |
'ne': 58, | |
'nl': 59, | |
'no': 60, | |
'om': 61, | |
'or': 62, | |
'pa': 63, | |
'pl': 64, | |
'ps': 65, | |
'pt': 66, | |
'ro': 67, | |
'ru': 68, | |
'sa': 69, | |
'sd': 70, | |
'si': 71, | |
'sk': 72, | |
'sl': 73, | |
'so': 74, | |
'sq': 75, | |
'sr': 76, | |
'su': 77, | |
'sv': 78, | |
'sw': 79, | |
'ta': 80, | |
'te': 81, | |
'th': 82, | |
'tl': 83, | |
'tr': 84, | |
'ug': 85, | |
'uk': 86, | |
'ur': 87, | |
'uz': 88, | |
'vi': 89, | |
'xh': 90, | |
'yi': 91, | |
'zh': 92, | |
"_math": 93 | |
} |