import os
import re
import math
import torch
import torch.nn as nn
from .clip_encoder import CLIPVisionTower
from .eva_clip_encoder import EvaClipVisionTower
from .siglip_encoder import SiglipVisionTower
from .google_siglip_encoder import GoogleSiglipVisionTower
from llava.model.utils import LayerNorm
from .qformer import BertConfig, BertLMHeadModel
from .resampler import Resampler, TokenCompressor
from torch.nn.init import trunc_normal_


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    # is_absolute_path_exists = os.path.exists(vision_tower)
    if vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        vision_tower = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif vision_tower.startswith("eva"):
        vision_tower = EvaClipVisionTower(vision_tower, args=vision_tower_cfg)
    elif vision_tower.startswith("google/siglip"):
        vision_tower = GoogleSiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif 'HuggingFaceM4/siglip' in vision_tower:
        vision_tower = SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    else:
        raise ValueError(f'Unknown vision tower: {vision_tower}')
    return vision_tower


def build_Qformer(num_query_token, vision_width, extra_num_query_token=64, cross_attention_freq=2):
    ln_vision = LayerNorm(vision_width)
    encoder_config = BertConfig.from_pretrained("./model/bert-base-uncased")
    encoder_config.encoder_width = vision_width
    # insert a cross-attention layer every `cross_attention_freq` blocks (every other block by default)
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = cross_attention_freq
    encoder_config.query_length = num_query_token
    Qformer = BertLMHeadModel(config=encoder_config)
    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
    # drop the LM head and the text input embeddings; only the query / cross-attention path is used
    Qformer.cls = None
    Qformer.bert.embeddings.word_embeddings = None
    Qformer.bert.embeddings.position_embeddings = None
    for layer in Qformer.bert.encoder.layer:
        layer.output = None
        layer.intermediate = None
    return Qformer, ln_vision, query_tokens


# TODO: remove the vision_width argument here
def build_adapter_module(cfg, vision_width):
    return AdapterModule(cfg, vision_width)


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class AdapterModule(nn.Module):
    def __init__(self, config, vision_width):
        super().__init__()
        self.adapter_name = config.adapter_module_name
        self.config = config
        self.output_dim = vision_width
        if 'perceiver' in self.adapter_name:
            from flash_perceiver import Perceiver
            self.adapter = Perceiver(
                input_dim=vision_width,
                depth=6,
                output_dim=vision_width,
                num_latents=self.config.num_query_token,
                latent_dim=1024,
                cross_heads=1,
                cross_head_dim=128,
                cross_rotary_emb_dim=0,
                cross_attn_dropout=0.0,
                latent_heads=8,
                latent_head_dim=128,
                latent_rotary_emb_dim=0,
                latent_attn_dropout=0.0,
                weight_tie_layers=False,
                gated_mlp=True,
                self_per_cross_attn=1,
                num_zero_tokens=None,
                use_flash_attn=True,
            )
        elif 'naive_resampler' in self.adapter_name:
            assert math.sqrt(self.config.num_query_token) ** 2 == self.config.num_query_token, \
                'num_query_token needs to be a square number'
            self.adapter = Resampler(
                grid_size=int(math.sqrt(self.config.num_query_token)),
                embed_dim=vision_width,
                num_heads=8,
            )
        elif 'qformer' in self.adapter_name:
            Qformer, ln_vision, query_tokens = build_Qformer(
                self.config.num_query_token, vision_width)
            self.adapter = Qformer
            self.ln_vision = ln_vision
            self.query_tokens = query_tokens
            self.output_dim = Qformer.config.hidden_size
        elif 'none' in self.adapter_name:
            self.adapter = IdentityMap()

        self.is_loaded = False

        if 'compress_token' in self.adapter_name:
            match = re.search(r'\d+$', self.adapter_name)
            self.token_compressor = TokenCompressor(
                num_compressed_token=int(match.group()),
                embed_dim=self.config.hidden_size,
                num_heads=8,
            )
            if 'v1' in self.adapter_name:
                self.compress_version = 'v1'
            else:
                self.compress_version = 'v0'

        # self.ln_vision = LayerNorm(self.config.vision_in_dim)
        self.frame_position_encoding = nn.Embedding(
            config.max_num_segments,
            self.output_dim,
        )
        self.adapter.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, image_features, frame_ids):
        if 'perceiver' in self.adapter_name:
            adapted_image_features = self.adapter(image_features, return_embeddings=True)
        elif 'naive_resampler' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)
        elif 'qformer' in self.adapter_name:
            image_features = self.ln_vision(image_features)
            query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
            attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device)
            adapted_image_features = self.adapter.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_features,
                encoder_attention_mask=attn_mask,
                return_dict=True,
            ).last_hidden_state
        elif 'none' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)

        frame_embeddings = self.frame_position_encoding(frame_ids).unsqueeze(-2)
        adapted_image_features += frame_embeddings
        return adapted_image_features

    # TODO: ad-hoc function, rewrite it in the future
    def compress_token_per_img(self, batch_image_features):
        if 'compress_token' not in self.adapter_name:
            return batch_image_features
        compressed_features = []
        for image_features in batch_image_features:
            # image_features: [num_frames, tokens, C]
            # handle non-image cases, where the number of image patches may be
            # smaller than num_compressed_token
            if image_features.shape[1] < self.token_compressor.num_compressed_token:
                compressed_features.append(image_features)
            else:
                compressed_features.append(
                    self.token_compressor(image_features, compress_version=self.compress_version))
        return compressed_features

    def load_model(self):
        if self.is_loaded:
            return
        if getattr(self.config, 'adapter_module_path', None):
            checkpoint = torch.load(self.config.adapter_module_path, map_location="cpu")

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
            def get_variable_frame_encoding_w(model_weights, load_weights):
                keyword = 'frame_position_encoding'
                model_len = model_weights.shape[0]
                load_weights_f_encoding = get_w(load_weights, keyword)
                load_len = load_weights_f_encoding['weight'].shape[0]
                if model_len <= load_len:
                    value = load_weights_f_encoding['weight'][:model_len]
                else:
                    value = model_weights.clone().cpu()
                    value[:load_len] = load_weights_f_encoding['weight']
                return value

            if 'qformer' in self.adapter_name and ('projector.bin' not in self.config.adapter_module_path):
                state_dict = checkpoint["model"]
                self.adapter.load_state_dict(get_w(state_dict, 'Qformer'))
                self.ln_vision.load_state_dict(get_w(state_dict, 'ln_vision'))
                self.load_state_dict({'query_tokens': state_dict['query_tokens']}, strict=False)
                if getattr(self.config, 'pretrain_mm_mlp_adapter', None):
                    mm_projector_weights = torch.load(self.config.pretrain_mm_mlp_adapter, map_location='cpu')
                    frame_encoding_weight = get_variable_frame_encoding_w(
                        self.frame_position_encoding.weight, mm_projector_weights)
                    self.frame_position_encoding.load_state_dict({'weight': frame_encoding_weight})
            else:
                frame_encoding_weight = get_variable_frame_encoding_w(
                    self.frame_position_encoding.weight, checkpoint)
                for k in checkpoint.keys():
                    if 'frame_position_encoding' in k:
                        checkpoint[k] = frame_encoding_weight
                self.load_state_dict(get_w(checkpoint, 'adapter_module'))
        else:
            # no pretrained weights, keep the default initialization
            return

    def freeze_adapter_module(self, freeze_flag):
        if freeze_flag:
            for name, p in self.named_parameters():
                p.requires_grad = False
        else:
            for name, p in self.named_parameters():
                p.requires_grad = True
            if 'naive_resampler' in self.adapter_name:
                # keep the resampler's positional embedding frozen even when unfreezing
                for name, p in self.named_parameters():
                    if 'pos_embed' in name:
                        p.requires_grad = False
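

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# shows how build_adapter_module() is typically driven. The SimpleNamespace
# config and the values below are assumptions for illustration; in practice
# these fields come from the model/training config. The 'none' adapter is used
# here because 'qformer' needs the local ./model/bert-base-uncased checkpoint
# and 'perceiver' needs the flash_perceiver package. Because of the relative
# imports above, run this as a module (python -m <package>.builder).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical config mirroring the attributes AdapterModule reads above.
    cfg = SimpleNamespace(
        adapter_module_name='none',   # IdentityMap branch: no extra dependencies
        max_num_segments=64,          # size of the frame position embedding table
        adapter_module_path=None,     # no pretrained adapter -> load_model() is a no-op
    )

    adapter = build_adapter_module(cfg, vision_width=1024)
    adapter.load_model()

    # Fake vision-tower features for 4 frames: [num_frames, tokens, vision_width]
    image_features = torch.randn(4, 576, 1024)
    frame_ids = torch.arange(4)

    out = adapter(image_features, frame_ids)
    print(out.shape)  # torch.Size([4, 576, 1024]), with a frame position encoding added per frame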