Spaces:
Running
Running
''' | |
* Adapted from BLIP (https://github.com/salesforce/BLIP) | |
''' | |
import warnings | |
warnings.filterwarnings("ignore") | |
import torch | |
import os | |
from urllib.parse import urlparse | |
from timm.models.hub import download_cached_file | |
from transformers import BertTokenizer | |
from .vit import VisionTransformer, interpolate_pos_embed | |
def init_tokenizer():
    """Create the ``bert-base-uncased`` tokenizer extended with BLIP's special tokens.

    ``[DEC]`` is registered as the BOS token and ``[ENC]`` as an additional
    special token. The id of ``[ENC]`` is then attached to the tokenizer as
    ``enc_token_id``.

    Returns:
        The configured ``BertTokenizer``.
    """
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    # Register the BOS token first, then the extra '[ENC]' token (same order
    # and same two calls as before).
    for special in ({'bos_token': '[DEC]'},
                    {'additional_special_tokens': ['[ENC]']}):
        tok.add_special_tokens(special)
    # Relies on '[ENC]' being the first entry of the additional special tokens.
    tok.enc_token_id = tok.additional_special_tokens_ids[0]
    return tok
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build a ViT vision encoder (16x16 patches) of the requested size.

    Args:
        vit: ``'base'`` (width 768, depth 12, 12 heads) or ``'large'``
            (width 1024, depth 24, 16 heads).
        image_size: Input resolution, forwarded as ``img_size``.
        use_grad_checkpointing: Forwarded unchanged to ``VisionTransformer``.
        ckpt_layer: Forwarded unchanged to ``VisionTransformer``.
        drop_path_rate: Drop-path rate forwarded to ``VisionTransformer``.
            For ``'large'``, a falsy value (including the default ``0``) is
            replaced by ``0.1``.

    Returns:
        Tuple ``(visual_encoder, vision_width)``.

    Raises:
        AssertionError: If ``vit`` is neither ``'base'`` nor ``'large'``.
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    # (embed_dim, depth, num_heads) per size. The dict lookup also fails
    # loudly (KeyError) if the assert above is stripped by ``python -O``.
    vision_width, depth, num_heads = {
        'base': (768, 12, 12),
        'large': (1024, 24, 16),
    }[vit]
    if vit == 'large':
        # The fallback must come second: ``0.1 or drop_path_rate`` always
        # evaluates to 0.1 and would silently discard the caller's value.
        drop_path_rate = drop_path_rate or 0.1
    visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=depth,
                                       num_heads=num_heads, use_grad_checkpointing=use_grad_checkpointing,
                                       ckpt_layer=ckpt_layer, drop_path_rate=drop_path_rate)
    return visual_encoder, vision_width
def is_url(url_or_filename):
    """Return True when *url_or_filename* parses with an ``http`` or ``https`` scheme.

    Local paths, and URLs with any other scheme (``ftp``, ``file``, ...),
    yield False.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
def load_checkpoint(model, url_or_filename):
    """Load the ``'model'`` weights of a checkpoint into *model*, tolerating mismatches.

    The checkpoint is fetched with ``download_cached_file`` when
    *url_or_filename* is an http(s) URL, or read directly when it is an
    existing file. The vision-encoder position embeddings are passed through
    ``interpolate_pos_embed`` against the model's encoder(s). Checkpoint
    entries whose shape differs from the model's are printed and dropped.
    The remainder is loaded with ``strict=False``.

    Args:
        model: Module exposing ``visual_encoder`` (and optionally
            ``visual_encoder_m``) plus ``state_dict``/``load_state_dict``.
        url_or_filename: http(s) URL or local path of the checkpoint.

    Returns:
        ``(model, msg)`` where ``msg`` is what ``load_state_dict`` returned.

    Raises:
        RuntimeError: If *url_or_filename* is neither a URL nor an existing file.
    """
    if is_url(url_or_filename):
        checkpoint_path = download_cached_file(url_or_filename, check_hash=False, progress=True)
    elif os.path.isfile(url_or_filename):
        checkpoint_path = url_or_filename
    else:
        raise RuntimeError('checkpoint url or path is invalid')
    # SECURITY: torch.load unpickles arbitrary objects. Only load checkpoints
    # from trusted sources; newer PyTorch versions offer ``weights_only=True``.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['model']

    # model.state_dict() assembles a new mapping on every call. The model is
    # not modified until load_state_dict below, so fetch it once instead of
    # once (or more) per key.
    model_state = model.state_dict()

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
        state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model_state:
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)

    # Drop checkpoint tensors that cannot be loaded because of a shape
    # mismatch, so that strict=False loading does not fail on them.
    for key, param in model_state.items():
        if key in state_dict and state_dict[key].shape != param.shape:
            print(key, ": ", state_dict[key].shape, ', ', param.shape)
            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg