from dataclasses import dataclass

import torch
from transformers import PretrainedConfig
from transformers.utils import ModelOutput


class SuryaOCRConfig(PretrainedConfig):
    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        encoder_config = kwargs.pop("encoder")
        decoder_config = kwargs.pop("decoder")

        self.encoder = encoder_config
        self.decoder = decoder_config
        self.is_encoder_decoder = True

        # Surface the decoder's token ids at the top level; the sub-config may
        # arrive as a plain dict (e.g. loaded from JSON) or as a config object.
        if isinstance(decoder_config, dict):
            self.decoder_start_token_id = decoder_config["bos_token_id"]
            self.pad_token_id = decoder_config["pad_token_id"]
            self.eos_token_id = decoder_config["eos_token_id"]
        else:
            self.decoder_start_token_id = decoder_config.bos_token_id
            self.pad_token_id = decoder_config.pad_token_id
            self.eos_token_id = decoder_config.eos_token_id


class DonutSwinConfig(PretrainedConfig):
    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(256, 896),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=[2, 2, 14, 2],
        num_heads=[4, 8, 16, 32],
        num_kv_heads=[1, 2, 4, 8],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=256,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # We set the hidden_size attribute in order to make Swin work with
        # VisionEncoderDecoderModel: it indicates the channel dimension after
        # the last stage of the model.
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
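# Illustrative sketch, not part of the original module: how the composed config
# fits together. DonutSwinConfig derives hidden_size from the last Swin stage
# (128 * 2 ** 3 = 1024), which matches the default encoder_hidden_size of
# SuryaOCRDecoderConfig (defined below), and SuryaOCRConfig lifts the decoder's
# token ids to the top level whether the sub-config is a dict or an object.
#
#     encoder = DonutSwinConfig()
#     decoder = SuryaOCRDecoderConfig()
#     assert encoder.hidden_size == decoder.encoder_hidden_size == 1024
#     config = SuryaOCRConfig(encoder=encoder.to_dict(), decoder=decoder.to_dict())
#     assert config.decoder_start_token_id == decoder.bos_token_id == 1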
class SuryaOCRDecoderConfig(PretrainedConfig):
    model_type = "surya_ocr"

    def __init__(
        self,
        num_hidden_layers=10,
        vocab_size=65792,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=1,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
        self_attn_layers=(0, 1, 3, 5, 7, 9),
        global_attn_layers=(0, 1, 3, 5, 7, 9),
        attention_dropout=0.0,
        num_key_value_heads=2,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        encoder_hidden_size=1024,
        causal=False,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("`num_key_value_heads` cannot be larger than `num_attention_heads`")
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        # Repeat the block pattern so it covers every layer, then truncate.
        return (self.block_types * 100)[: self.num_hidden_layers]
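# Illustrative sketch, not part of the original module: with the defaults above,
# every layer is an attention block and all 10 layers cross-attend to the image
# encoder; layers 0, 1, 3, 5, 7, 9 use global self-attention, while the rest
# fall back to the sliding attention_window_size.
#
#     cfg = SuryaOCRDecoderConfig()
#     assert cfg.layers_block_type == ["attention"] * 10
#     assert all(i in cfg.cross_attn_layers for i in range(cfg.num_hidden_layers))
#     assert cfg.head_dim == cfg.hidden_size // cfg.num_attention_heads == 64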
class SuryaOCRTextEncoderConfig(PretrainedConfig):
    model_type = "surya_ocr"

    def __init__(
        self,
        num_hidden_layers=10,
        vocab_size=65792,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=1,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
        self_attn_layers=(0, 1, 3, 5, 7, 9),
        global_attn_layers=(0, 1, 3, 5, 7, 9),
        attention_dropout=0.0,
        num_key_value_heads=2,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        encoder_hidden_size=1024,
        iteration_count=1,
        causal=False,
        query_token_count=128,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("`num_key_value_heads` cannot be larger than `num_attention_heads`")
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.iteration_count = iteration_count
        self.causal = causal
        self.query_token_count = query_token_count

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        # Repeat the block pattern so it covers every layer, then truncate.
        return (self.block_types * 100)[: self.num_hidden_layers]


TOTAL_TOKENS = 65536
TOKEN_OFFSET = 3  # Pad, eos, bos
SPECIAL_TOKENS = 253
TOTAL_VOCAB_SIZE = TOTAL_TOKENS + TOKEN_OFFSET + SPECIAL_TOKENS

LANGUAGE_MAP = {
    'af': 0, 'am': 1, 'ar': 2, 'as': 3, 'az': 4, 'be': 5, 'bg': 6, 'bn': 7,
    'br': 8, 'bs': 9, 'ca': 10, 'cs': 11, 'cy': 12, 'da': 13, 'de': 14, 'el': 15,
    'en': 16, 'eo': 17, 'es': 18, 'et': 19, 'eu': 20, 'fa': 21, 'fi': 22, 'fr': 23,
    'fy': 24, 'ga': 25, 'gd': 26, 'gl': 27, 'gu': 28, 'ha': 29, 'he': 30, 'hi': 31,
    'hr': 32, 'hu': 33, 'hy': 34, 'id': 35, 'is': 36, 'it': 37, 'ja': 38, 'jv': 39,
    'ka': 40, 'kk': 41, 'km': 42, 'kn': 43, 'ko': 44, 'ku': 45, 'ky': 46, 'la': 47,
    'lo': 48, 'lt': 49, 'lv': 50, 'mg': 51, 'mk': 52, 'ml': 53, 'mn': 54, 'mr': 55,
    'ms': 56, 'my': 57, 'ne': 58, 'nl': 59, 'no': 60, 'om': 61, 'or': 62, 'pa': 63,
    'pl': 64, 'ps': 65, 'pt': 66, 'ro': 67, 'ru': 68, 'sa': 69, 'sd': 70, 'si': 71,
    'sk': 72, 'sl': 73, 'so': 74, 'sq': 75, 'sr': 76, 'su': 77, 'sv': 78, 'sw': 79,
    'ta': 80, 'te': 81, 'th': 82, 'tl': 83, 'tr': 84, 'ug': 85, 'uk': 86, 'ur': 87,
    'uz': 88, 'vi': 89, 'xh': 90, 'yi': 91, 'zh': 92, '_math': 93,
}
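# Illustrative sketch, not part of the original module: the vocabulary packs
# 65536 text tokens after the 3 control ids plus 253 reserved special slots,
# so 65536 + 3 + 253 = 65792, the default vocab_size of both configs above;
# the 94 LANGUAGE_MAP entries fit inside the special range. The helper below
# is hypothetical and only shows one plausible way to place language tokens
# in that range.
#
#     assert TOTAL_VOCAB_SIZE == SuryaOCRDecoderConfig().vocab_size == 65792
#     assert len(LANGUAGE_MAP) <= SPECIAL_TOKENS
#
#     def language_token_id(lang: str) -> int:
#         # Hypothetical: offset the language id past the text tokens.
#         return TOTAL_TOKENS + TOKEN_OFFSET + LANGUAGE_MAP[lang]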