from typing import List

from transformers import PretrainedConfig, T5Config
class SwinVilmaConfig(PretrainedConfig):
    """Vision-encoder configuration for VisFocus: a Swin Transformer backbone
    extended with vision-language (VL) cross-attention layers."""

    model_type = "swin_vilma"

    def __init__(
        self,
        patch_size: int = 4,
        in_chans: int = 3,
        embed_dim: int = 96,
        depths: List[int] = [2, 2, 18, 2],
        num_heads: List[int] = [3, 6, 12, 24],
        window_size: int = 24,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ape: bool = False,
        patch_norm: bool = True,
        pretrained_window_sizes: List[int] = [0, 0, 0, 0],
        vl_cross_attn_layers: List[int] = [3],
        vl_alpha: float = 0.5,
        lm_d_model: int = 768,
        text_embedder: str = "t5-base",
        downsampling_method: str = "merge_attention_v3",
        vision_name: str = "swin_small_patch4_window7_224_22k",
        image_size: List[int] = [1536, 768],
        drop_path_rate: float = 0.3,
        drop_rate: float = 0.0,
        resume_from: str = "",
        use_checkpoint: bool = False,
        do_shift: bool = True,
        input_type: str = "rgb",
        vl_learned_ape: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Swin backbone hyper-parameters.
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.ape = ape
        self.patch_norm = patch_norm
        self.pretrained_window_sizes = pretrained_window_sizes

        # Vision-language cross-attention and text-embedder settings.
        self.vl_cross_attn_layers = vl_cross_attn_layers
        self.vl_alpha = vl_alpha
        self.lm_d_model = lm_d_model
        self.text_embedder = text_embedder

        self.downsampling_method = downsampling_method
        self.vision_name = vision_name
        self.image_size = image_size
        self.drop_path_rate = drop_path_rate
        self.drop_rate = drop_rate
        self.resume_from = resume_from
        self.use_checkpoint = use_checkpoint
        self.do_shift = do_shift
        self.input_type = input_type
        self.vl_learned_ape = vl_learned_ape


class VisFocusConfig(PretrainedConfig):
    """Top-level VisFocus configuration. Bundles a SwinVilmaConfig as the
    vision config and a T5Config as the language-model config."""

    model_type = "visfocus"

    def __init__(
        self,
        initializer_factor: float = 1.0,
        initializer_range: float = 0.02,
        max_seq_length: int = 2048,
        generate_max_new_tokens_len: int = 256,
        model_name_or_path: str = "",
        variant: str = "vf-base",
        image_size: List[int] = [1536, 768],
        seed: int = 42,
        do_lower_case: bool = True,
        hidden_dropout_prob: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.initializer_factor = initializer_factor
        self.initializer_range = initializer_range
        self.max_seq_length = max_seq_length
        self.generate_max_new_tokens_len = generate_max_new_tokens_len
        self.model_name_or_path = model_name_or_path
        self.variant = variant
        self.image_size = image_size
        self.seed = seed
        self.do_lower_case = do_lower_case
        self.hidden_dropout_prob = hidden_dropout_prob
        # Sub-configs are always instantiated with their default values.
        self.vision_config = SwinVilmaConfig()
        self.lm_config = T5Config()
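

# Usage sketch: a minimal example of how these configs compose, relying only
# on the attributes defined above and the standard PretrainedConfig API from
# transformers. The override values below are illustrative, not values
# required by the model.
if __name__ == "__main__":
    config = VisFocusConfig(variant="vf-base")
    print(config.model_type)                # "visfocus"
    print(config.vision_config.depths)      # [2, 2, 18, 2]
    print(config.lm_config.d_model)         # 512 for a default T5Config

    # The vision config is itself a PretrainedConfig, so it serializes on its own.
    swin_config = SwinVilmaConfig(window_size=12, image_size=[768, 768])
    print(swin_config.to_dict()["window_size"])  # 12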