# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from torchscale.architecture.config import EncoderConfig
from torchscale.model.BEiT3 import BEiT3
def trunc_normal_(tensor, mean=0.0, std=1.0):
    """In-place truncated-normal init, clamped to one std either side of the mean.

    Delegates to timm's ``trunc_normal_`` with bounds ``[mean - std, mean + std]``
    (note: bounds are ``-std``/``+std`` as passed, matching the original call).
    """
    lower, upper = -std, std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=lower, b=upper)
def _get_base_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    **kwargs
):
    """Build the BEiT-3 *base* encoder config: 768-dim, 12 heads, 12 layers.

    Extra ``**kwargs`` are accepted for call-site compatibility but ignored,
    as in the original.
    """
    embed_dim = 768
    # Options shared by the base and large presets.
    shared = dict(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        checkpoint_activations=checkpoint_activations,
    )
    return EncoderConfig(
        encoder_embed_dim=embed_dim,
        encoder_attention_heads=12,
        encoder_ffn_embed_dim=int(embed_dim * mlp_ratio),
        encoder_layers=12,
        **shared,
    )
def _get_large_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    **kwargs
):
    """Build the BEiT-3 *large* encoder config: 1024-dim, 16 heads, 24 layers.

    Extra ``**kwargs`` are accepted for call-site compatibility but ignored,
    as in the original.
    """
    embed_dim = 1024
    # Options shared by the base and large presets.
    shared = dict(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        checkpoint_activations=checkpoint_activations,
    )
    return EncoderConfig(
        encoder_embed_dim=embed_dim,
        encoder_attention_heads=16,
        encoder_ffn_embed_dim=int(embed_dim * mlp_ratio),
        encoder_layers=24,
        **shared,
    )
class BEiT3Wrapper(nn.Module):
    """Thin ``nn.Module`` wrapper around the torchscale ``BEiT3`` backbone.

    Instantiates the backbone from ``args`` and applies truncated-normal
    initialization to every Linear/LayerNorm submodule.
    """

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.beit3 = BEiT3(args)
        self.apply(self._init_weights)

    def fix_init_weight(self):
        # NOTE(review): reads ``self.blocks``, which this wrapper never
        # defines (the backbone lives at ``self.beit3``) -- presumably dead
        # code carried over from a ViT-style base class; confirm before
        # calling.
        def rescale(param, depth):
            # Depth-dependent residual-branch rescaling (as in BEiT/ViT).
            param.div_(math.sqrt(2.0 * depth))

        for depth, block in enumerate(self.blocks, start=1):
            rescale(block.attn.proj.weight.data, depth)
            rescale(block.mlp.fc2.weight.data, depth)

    def get_num_layers(self):
        """Number of encoder layers in the wrapped backbone."""
        return self.beit3.encoder.num_layers

    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {
            'pos_embed',
            'cls_token',
            'beit3.encoder.embed_positions.A.weight',
            'beit3.vision_embed.cls_token',
            'logit_scale',
        }

    def _init_weights(self, m):
        # Linear: truncated-normal weights, zero bias; LayerNorm: identity.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)