# configuration_visfocus.py: configuration classes for VisFocus (visfocus-base-docvqa)
from typing import List
from transformers import PretrainedConfig, T5Config


class SwinVilmaConfig(PretrainedConfig):
    """Configuration for the Swin-ViLMA vision encoder used by VisFocus.

    Holds the Swin-transformer hyper-parameters (patch size, depths, heads,
    window size, drop rates, ...) together with the vision-language
    cross-attention settings (the vl_* fields) and the input/image options.
    """

    model_type = "swin_vilma"

def __init__(
self,
patch_size: int = 4,
in_chans: int = 3,
embed_dim: int = 96,
depths: List[int] = [2, 2, 18, 2],
num_heads: List[int] = [3, 6, 12, 24],
window_size: int = 24,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
ape: bool = False,
patch_norm: bool = True,
pretrained_window_sizes: List[int] = [0, 0, 0, 0],
vl_cross_attn_layers: List[int] = [3],
vl_alpha: float = 0.5,
lm_d_model: int = 768,
text_embedder: str = "t5-base",
downsampling_method: str = "merge_attention_v3",
vision_name: str = "swin_small_patch4_window7_224_22k",
image_size: List[int] = [1536, 768],
drop_path_rate: float = 0.3,
drop_rate: float = 0.0,
resume_from: str = "",
use_checkpoint: bool = False,
do_shift: bool = True,
input_type: str = "rgb",
vl_learned_ape: bool = True,
**kwargs
):
super().__init__(**kwargs)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.depths = depths
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.ape = ape
self.patch_norm = patch_norm
self.pretrained_window_sizes = pretrained_window_sizes
self.vl_cross_attn_layers = vl_cross_attn_layers
self.vl_alpha = vl_alpha
self.lm_d_model = lm_d_model
self.text_embedder = text_embedder
self.downsampling_method = downsampling_method
self.vision_name = vision_name
self.image_size = image_size
self.drop_path_rate = drop_path_rate
self.drop_rate = drop_rate
self.resume_from = resume_from
self.use_checkpoint = use_checkpoint
self.do_shift = do_shift
self.input_type = input_type
self.vl_learned_ape = vl_learned_ape
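
# Note (illustrative, not part of the original file): since SwinVilmaConfig
# subclasses PretrainedConfig, individual fields can be overridden at
# construction time via keyword arguments, e.g.
# SwinVilmaConfig(window_size=12, drop_path_rate=0.2), and an instance can be
# serialized with the standard to_dict() / to_json_string() helpers.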


class VisFocusConfig(PretrainedConfig):
    """Top-level configuration for the VisFocus document-VQA model.

    Stores generation, sequence-length and initialization hyper-parameters and
    holds two nested sub-configurations: vision_config (SwinVilmaConfig) for
    the vision encoder and lm_config (T5Config) for the language model.
    """

    model_type = "visfocus"

def __init__(
self,
initializer_factor: float = 1.0,
initializer_range: float = 0.02,
max_seq_length: int = 2048,
generate_max_new_tokens_len: int = 256,
model_name_or_path: str = "",
variant: str = "vf-base",
image_size: List[int] = [1536, 768],
seed: int = 42,
do_lower_case: bool = True,
        hidden_dropout_prob: float = 0.1,
**kwargs
):
super().__init__(**kwargs)
self.initializer_factor = initializer_factor
self.initializer_range = initializer_range
self.max_seq_length = max_seq_length
self.generate_max_new_tokens_len = generate_max_new_tokens_len
self.model_name_or_path = model_name_or_path
self.variant = variant
self.image_size = image_size
self.seed = seed
self.do_lower_case = do_lower_case
self.hidden_dropout_prob = hidden_dropout_prob
        # Nested sub-configurations: the Swin-ViLMA vision encoder and the T5
        # language model are instantiated with their default hyper-parameters.
        self.vision_config = SwinVilmaConfig()
        self.lm_config = T5Config()
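

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file): build
    # the top-level config, inspect the nested sub-configs, and round-trip it
    # through the standard PretrainedConfig serialization API. The output
    # directory name below is arbitrary.
    config = VisFocusConfig(variant="vf-base", image_size=[1536, 768])
    print(config.model_type)                # "visfocus"
    print(config.vision_config.model_type)  # "swin_vilma"
    print(config.lm_config.model_type)      # "t5"

    config.save_pretrained("./visfocus_config_demo")
    reloaded = VisFocusConfig.from_pretrained("./visfocus_config_demo")
    print(reloaded.max_seq_length)          # 2048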