'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''
import warnings
warnings.filterwarnings("ignore")

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer

from .vit import VisionTransformer, interpolate_pos_embed


def init_tokenizer():
    """Create the ``bert-base-uncased`` tokenizer with two extra special tokens.

    ``[DEC]`` is registered as the ``bos_token`` and ``[ENC]`` as an additional
    special token. The id of ``[ENC]`` is stored on the returned tokenizer as
    ``enc_token_id``.

    Returns:
        The configured ``BertTokenizer`` instance.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    # '[ENC]' is the only additional special token, so it sits at index 0.
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build a ``VisionTransformer`` of the requested size.

    Args:
        vit: ``'base'`` (width 768, depth 12, 12 heads) or ``'large'``
            (width 1024, depth 24, 16 heads). Both use a patch size of 16.
        image_size: Forwarded to ``VisionTransformer`` as ``img_size``.
        use_grad_checkpointing: Forwarded to ``VisionTransformer``.
        ckpt_layer: Forwarded to ``VisionTransformer``.
        drop_path_rate: Forwarded to ``VisionTransformer``. For ``'large'`` a
            falsy value (the default ``0``, or ``None``) selects 0.1, so an
            explicit rate of exactly 0 cannot be requested for that size.

    Returns:
        A ``(visual_encoder, vision_width)`` tuple.

    Raises:
        AssertionError: If ``vit`` is neither ``'base'`` nor ``'large'``.
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=12,
            num_heads=12,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            drop_path_rate=drop_path_rate,
        )
    elif vit == 'large':
        vision_width = 1024
        # The previous expression `0.1 or drop_path_rate` always evaluated to
        # 0.1 and silently discarded the caller's value. Honour the argument
        # and fall back to 0.1 only when it is left at a falsy value.
        large_drop_path_rate = drop_path_rate if drop_path_rate else 0.1
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=24,
            num_heads=16,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            drop_path_rate=large_drop_path_rate,
        )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    """Return True when ``url_or_filename`` parses with an http or https scheme."""
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    """Load checkpoint weights into ``model`` from a URL or a local file.

    The checkpoint's ``'model'`` entry is used as the state dict. The
    ``visual_encoder.pos_embed`` entry (and ``visual_encoder_m.pos_embed`` when
    the model has that key) is passed through ``interpolate_pos_embed`` before
    loading. Entries whose shape differs from the model's are printed and
    dropped, and the remainder is loaded with ``strict=False``.

    Args:
        model: Module exposing ``visual_encoder`` (and ``visual_encoder_m`` if
            its state dict contains ``visual_encoder_m.pos_embed``).
        url_or_filename: http(s) URL or path to an existing checkpoint file.

    Returns:
        A ``(model, msg)`` tuple, where ``msg`` is the value returned by
        ``model.load_state_dict``.

    Raises:
        RuntimeError: If ``url_or_filename`` is neither an http(s) URL nor an
            existing file.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from sources you trust.
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    # Take one snapshot of the model's state dict instead of rebuilding it on
    # every loop iteration; nothing below modifies the model's parameters
    # until load_state_dict is called.
    model_state = model.state_dict()

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
        state_dict['visual_encoder.pos_embed'], model.visual_encoder
    )
    if 'visual_encoder_m.pos_embed' in model_state:
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
            state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m
        )

    # Drop checkpoint tensors whose shape disagrees with the model so that the
    # non-strict load below skips them instead of raising on a size mismatch.
    for key, model_tensor in model_state.items():
        if key in state_dict and state_dict[key].shape != model_tensor.shape:
            print(key, ": ", state_dict[key].shape, ', ', model_tensor.shape)
            del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg