File size: 1,088 Bytes
8eb6782 ca39bed 8eb6782 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from transformers import PretrainedConfig
from torchscale.architecture.config import EncoderConfig
class ViVQAConfig(PretrainedConfig):
    """Hugging Face config for the ViVQA model, built on a torchscale encoder config.

    The constructor builds a torchscale ``EncoderConfig`` from the encoder
    hyperparameters below. It then copies every instance attribute of that
    object onto this config, so this config exposes the same attribute names
    as the encoder config.

    Args:
        drop_path_rate: Passed to ``EncoderConfig`` as ``drop_path_rate``.
        mlp_ratio: Feed-forward width as a multiple of ``encoder_embed_dim``;
            ``encoder_ffn_embed_dim`` is set to
            ``int(encoder_embed_dim * mlp_ratio)`` (truncated toward zero).
        encoder_layers: Passed to ``EncoderConfig`` as ``encoder_layers``.
        encoder_attention_heads: Passed to ``EncoderConfig`` as
            ``encoder_attention_heads``.
        multiway: Passed to ``EncoderConfig`` as ``multiway``.
        layernorm_embedding: Passed to ``EncoderConfig`` as
            ``layernorm_embedding``.
        normalize_output: Passed to ``EncoderConfig`` as ``normalize_output``.
        no_output_layer: Passed to ``EncoderConfig`` as ``no_output_layer``.
        encoder_embed_dim: Encoder hidden size. Passed to ``EncoderConfig`` and
            also used as the base for the feed-forward width.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "vivqa"

    def __init__(
        self,
        drop_path_rate: float = 0.0,
        mlp_ratio: float = 4.0,
        encoder_layers: int = 4,
        encoder_attention_heads: int = 4,
        multiway: bool = True,
        layernorm_embedding: bool = False,
        normalize_output: bool = True,
        no_output_layer: bool = True,
        encoder_embed_dim: int = 768,
        **kwargs,
    ) -> None:
        # Both the hidden size and the FFN width come from ``encoder_embed_dim``
        # so the argument is actually honored (the default of 768 keeps the
        # out-of-the-box configuration unchanged).
        args = EncoderConfig(
            multiway=multiway,
            layernorm_embedding=layernorm_embedding,
            normalize_output=normalize_output,
            no_output_layer=no_output_layer,
            drop_path_rate=drop_path_rate,
            encoder_embed_dim=encoder_embed_dim,
            encoder_attention_heads=encoder_attention_heads,
            encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
            encoder_layers=encoder_layers,
        )
        # Mirror every encoder setting onto this config object.
        for key, value in vars(args).items():
            setattr(self, key, value)
        # NOTE(review): this runs after the copy above. PretrainedConfig
        # presumably sets leftover kwargs as attributes, which would override
        # any copied encoder setting of the same name — confirm.
        super().__init__(**kwargs)