configuration_vivqa.py
from transformers import PretrainedConfig
from torchscale.architecture.config import EncoderConfig


class ViVQAConfig(PretrainedConfig):
    """Configuration for the ViVQA model, backed by a torchscale multiway encoder."""

    model_type = "vivqa"

    def __init__(
        self,
        drop_path_rate: float = 0.0,
        mlp_ratio: float = 4.0,
        encoder_layers: int = 6,
        encoder_attention_heads: int = 6,
        multiway: bool = True,
        layernorm_embedding: bool = False,
        normalize_output: bool = True,
        no_output_layer: bool = True,
        encoder_embed_dim: int = 768,
        **kwargs
    ):
        # Build the torchscale encoder config. The FFN width is derived
        # from the embedding dimension via mlp_ratio rather than hardcoded,
        # so changing encoder_embed_dim scales the whole encoder consistently.
        args = EncoderConfig(
            multiway=multiway,
            layernorm_embedding=layernorm_embedding,
            normalize_output=normalize_output,
            no_output_layer=no_output_layer,
            drop_path_rate=drop_path_rate,
            encoder_embed_dim=encoder_embed_dim,
            encoder_attention_heads=encoder_attention_heads,
            encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
            encoder_layers=encoder_layers,
        )
        # Mirror every encoder field onto this config so it is serialized
        # alongside the standard PretrainedConfig attributes.
        for key, value in args.__dict__.items():
            setattr(self, key, value)
        super().__init__(**kwargs)
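
A minimal usage sketch, assuming transformers and torchscale are installed and this module is importable; it instantiates the config with its defaults and registers it with AutoConfig under the "vivqa" model type:

from transformers import AutoConfig

from configuration_vivqa import ViVQAConfig

# Register the custom config so AutoConfig can resolve model_type "vivqa".
AutoConfig.register("vivqa", ViVQAConfig)

# With the defaults above, the FFN width is int(768 * 4.0) = 3072.
config = ViVQAConfig()
assert config.encoder_ffn_embed_dim == 3072
print(config.model_type, config.encoder_layers)  # vivqa 6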