from typing import List
from transformers import PretrainedConfig, T5Config


class SwinVilmaConfig(PretrainedConfig):
    """Swin-based vision encoder configuration with vision-language (VL)
    cross-attention settings, used as the vision sub-config of VisFocus."""

    model_type = "swin_vilma"

    def __init__(
        self,
        patch_size: int = 4,
        in_chans: int = 3,
        embed_dim: int = 96,
        depths: List[int] = [2, 2, 18, 2],
        num_heads: List[int] = [3, 6, 12, 24],
        window_size: int = 24,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ape: bool = False,
        patch_norm: bool = True,
        pretrained_window_sizes: List[int] = [0, 0, 0, 0],
        vl_cross_attn_layers: List[int] = [3],
        vl_alpha: float = 0.5,
        lm_d_model: int = 768,
        text_embedder: str = "t5-base",
        downsampling_method: str = "merge_attention_v3",
        vision_name: str = "swin_small_patch4_window7_224_22k",
        image_size: List[int] = [1536, 768],
        drop_path_rate: float = 0.3,
        drop_rate: float = 0.0,
        resume_from: str = "",
        use_checkpoint: bool = False,
        do_shift: bool = True,
        input_type: str = "rgb",
        vl_learned_ape: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.ape = ape
        self.patch_norm = patch_norm
        self.pretrained_window_sizes = pretrained_window_sizes
        self.vl_cross_attn_layers = vl_cross_attn_layers
        self.vl_alpha = vl_alpha
        self.lm_d_model = lm_d_model
        self.text_embedder = text_embedder
        self.downsampling_method = downsampling_method
        self.vision_name = vision_name
        self.image_size = image_size
        self.drop_path_rate = drop_path_rate
        self.drop_rate = drop_rate
        self.resume_from = resume_from
        self.use_checkpoint = use_checkpoint
        self.do_shift = do_shift
        self.input_type = input_type
        self.vl_learned_ape = vl_learned_ape


class VisFocusConfig(PretrainedConfig):
    """Top-level VisFocus configuration: run/generation hyperparameters plus a
    nested SwinVilmaConfig (vision) and T5Config (language model)."""

    model_type = "visfocus"

    def __init__(
        self,
        initializer_factor: float = 1.0,
        initializer_range: float = 0.02,
        max_seq_length: int = 2048,
        generate_max_new_tokens_len: int = 256,
        model_name_or_path: str = "",
        variant: str = "vf-base",
        image_size: List[int] = [1536, 768],
        seed: int = 42,
        do_lower_case: bool = True,
        hidden_dropout_prob: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.initializer_factor = initializer_factor
        self.initializer_range = initializer_range
        self.max_seq_length = max_seq_length
        self.generate_max_new_tokens_len = generate_max_new_tokens_len
        self.model_name_or_path = model_name_or_path
        self.variant = variant
        self.image_size = image_size
        self.seed = seed
        self.do_lower_case = do_lower_case
        self.hidden_dropout_prob = hidden_dropout_prob
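        # Nested sub-configs: the vision tower (Swin + VL cross-attention) and the
        # T5 language model are instantiated with their defaults here; callers can
        # replace or adjust them after construction.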
        self.vision_config = SwinVilmaConfig()
        self.lm_config = T5Config()
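

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # build the top-level config, inspect its nested sub-configs, and round-trip
    # it through a plain dict using the standard PretrainedConfig API.
    config = VisFocusConfig(variant="vf-base", max_seq_length=2048)
    print(config.vision_config.depths)   # [2, 2, 18, 2] -- SwinVilmaConfig defaults
    print(config.lm_config.model_type)   # "t5" -- stock T5Config defaults
    restored = VisFocusConfig.from_dict(config.to_dict())
    print(restored.max_seq_length)       # 2048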