from transformers import PretrainedConfig
from torchscale.architecture.config import EncoderConfig

class ViVQAConfig(PretrainedConfig):
    model_type = "vivqa"
    
    def __init__(
        self,
        drop_path_rate: float = 0.0,
        mlp_ratio: float = 4.0,
        encoder_layers: int = 4,
        encoder_attention_heads: int = 4,
        multiway: bool = True,
        layernorm_embedding: bool = False,
        normalize_output: bool = True,
        no_output_layer: bool = True,
        encoder_embed_dim: int = 768,
        **kwargs,
    ):
        # Assemble the torchscale encoder configuration from the constructor
        # arguments. encoder_embed_dim sets the embedding width and, scaled
        # by mlp_ratio, the FFN hidden width.
        args = EncoderConfig(
            multiway=multiway,
            layernorm_embedding=layernorm_embedding,
            normalize_output=normalize_output,
            no_output_layer=no_output_layer,
            drop_path_rate=drop_path_rate,
            encoder_embed_dim=encoder_embed_dim,
            encoder_attention_heads=encoder_attention_heads,
            encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
            encoder_layers=encoder_layers,
        )
        # Mirror every EncoderConfig field onto this config so the values are
        # picked up by PretrainedConfig.to_dict() / save_pretrained().
        for key, value in args.__dict__.items():
            setattr(self, key, value)

        super().__init__(**kwargs)
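

# Illustrative usage sketch (an assumption, not part of the original file):
# builds a config and registers the custom model_type with AutoConfig, the
# standard transformers hook for custom config classes. Assumes both
# `transformers` and `torchscale` are importable.
if __name__ == "__main__":
    from transformers import AutoConfig

    # Map the "vivqa" model_type string to this config class so
    # AutoConfig.from_pretrained can resolve saved checkpoints.
    AutoConfig.register("vivqa", ViVQAConfig)

    config = ViVQAConfig(encoder_layers=6, encoder_attention_heads=8)
    print(config.encoder_embed_dim)      # 768 (default width)
    print(config.encoder_ffn_embed_dim)  # int(768 * 4.0) = 3072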