import os
import re
import math
import torch
import torch.nn as nn

from .clip_encoder import CLIPVisionTower
from .eva_clip_encoder import EvaClipVisionTower
from .siglip_encoder import SiglipVisionTower
from .google_siglip_encoder import GoogleSiglipVisionTower
from llava.model.utils import LayerNorm
from .qformer import BertConfig, BertLMHeadModel
from .resampler import Resampler, TokenCompressor
from torch.nn.init import trunc_normal_


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))

    if vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        vision_tower = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif vision_tower.startswith("eva"):
        vision_tower = EvaClipVisionTower(vision_tower, args=vision_tower_cfg)
    elif vision_tower.startswith("google/siglip"):
        vision_tower = GoogleSiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif 'HuggingFaceM4/siglip' in vision_tower:
        vision_tower = SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    else:
        raise ValueError(f'Unknown vision tower: {vision_tower}')

    return vision_tower
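
# Usage sketch (illustrative only): the config here is a stand-in for the model's
# argument namespace; any extra fields required by the chosen tower class must
# also be present on it.
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(mm_vision_tower="openai/clip-vit-large-patch14-336")
#     tower = build_vision_tower(cfg)  # dispatches to CLIPVisionTower
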


def build_Qformer(num_query_token, vision_width, extra_num_query_token=64, cross_attention_freq=2):
    # Note: `extra_num_query_token` is kept for API compatibility; it is not used here.
    ln_vision = LayerNorm(vision_width)
    encoder_config = BertConfig.from_pretrained("./model/bert-base-uncased")
    encoder_config.encoder_width = vision_width

    # Insert cross-attention to the image features every `cross_attention_freq` layers.
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = cross_attention_freq
    encoder_config.query_length = num_query_token
    Qformer = BertLMHeadModel(config=encoder_config)
    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

    # The Q-Former is used purely as a query encoder: drop the LM head, the text
    # token/position embeddings, and the text-branch feed-forward sub-layers.
    Qformer.cls = None
    Qformer.bert.embeddings.word_embeddings = None
    Qformer.bert.embeddings.position_embeddings = None
    for layer in Qformer.bert.encoder.layer:
        layer.output = None
        layer.intermediate = None

    return Qformer, ln_vision, query_tokens
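
# Shape sketch (illustrative): with 32 query tokens and a vision feature width of
# 1408, and given that bert-base-uncased has hidden_size 768:
#
#     Qformer, ln_vision, query_tokens = build_Qformer(num_query_token=32, vision_width=1408)
#     # ln_vision expects features of shape (batch, num_patches, 1408)
#     # query_tokens.shape == (1, 32, 768); the Q-Former output is (batch, 32, 768)
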


def build_adapter_module(cfg, vision_width):
    return AdapterModule(cfg, vision_width)


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class AdapterModule(nn.Module):
    def __init__(self, config, vision_width):
        super().__init__()
        self.adapter_name = config.adapter_module_name
        self.config = config
        self.output_dim = vision_width
        if 'perceiver' in self.adapter_name:
            from flash_perceiver import Perceiver
            self.adapter = Perceiver(
                input_dim=vision_width,
                depth=6,
                output_dim=vision_width,
                num_latents=self.config.num_query_token,
                latent_dim=1024,
                cross_heads=1,
                cross_head_dim=128,
                cross_rotary_emb_dim=0,
                cross_attn_dropout=0.0,
                latent_heads=8,
                latent_head_dim=128,
                latent_rotary_emb_dim=0,
                latent_attn_dropout=0.0,
                weight_tie_layers=False,
                gated_mlp=True,
                self_per_cross_attn=1,
                num_zero_tokens=None,
                use_flash_attn=True,
            )
        elif 'naive_resampler' in self.adapter_name:
            grid_size = math.isqrt(self.config.num_query_token)
            assert grid_size ** 2 == self.config.num_query_token, 'num_query_token must be a perfect square'
            self.adapter = Resampler(
                grid_size=grid_size,
                embed_dim=vision_width,
                num_heads=8,
            )
        elif 'qformer' in self.adapter_name:
            Qformer, ln_vision, query_tokens = build_Qformer(
                self.config.num_query_token, vision_width)
            self.adapter = Qformer
            self.ln_vision = ln_vision
            self.query_tokens = query_tokens
            self.output_dim = Qformer.config.hidden_size
        elif 'none' in self.adapter_name:
            self.adapter = IdentityMap()
        else:
            raise ValueError(f'Unknown adapter module: {self.adapter_name}')

        self.is_loaded = False

        if 'compress_token' in self.adapter_name:
            # The trailing integer in the adapter name gives the number of tokens
            # to keep per image.
            match = re.search(r'\d+$', self.adapter_name)
            self.token_compressor = TokenCompressor(
                num_compressed_token=int(match.group()),
                embed_dim=self.config.hidden_size,
                num_heads=8,
            )
            if 'v1' in self.adapter_name:
                self.compress_version = 'v1'
            else:
                self.compress_version = 'v0'

        self.frame_position_encoding = nn.Embedding(
            config.max_num_segments,
            self.output_dim,
        )

        self.adapter.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, image_features, frame_ids):
        if 'perceiver' in self.adapter_name:
            adapted_image_features = self.adapter(image_features, return_embeddings=True)
        elif 'naive_resampler' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)
        elif 'qformer' in self.adapter_name:
            image_features = self.ln_vision(image_features)
            query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
            attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device)
            adapted_image_features = self.adapter.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_features,
                encoder_attention_mask=attn_mask,
                return_dict=True,
            ).last_hidden_state
        elif 'none' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)

        # Add a learned per-frame position embedding, broadcast over the token dimension.
        # Use an out-of-place add so the identity adapter does not mutate its input tensor.
        frame_embeddings = self.frame_position_encoding(frame_ids).unsqueeze(-2)
        adapted_image_features = adapted_image_features + frame_embeddings
        return adapted_image_features
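
    # Shape sketch for `forward` (illustrative): with a qformer adapter,
    # `image_features` of shape (batch, num_patches, vision_width) and integer
    # `frame_ids` of shape (batch,) yield (batch, num_query_token, output_dim),
    # with the frame embedding broadcast across the query tokens.
    #
    #     feats = torch.randn(8, 256, 1408)       # dummy vision features
    #     frame_ids = torch.arange(8)             # one segment index per sample
    #     out = adapter_module(feats, frame_ids)  # (8, num_query_token, 768)
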

    def compress_token_per_img(self, batch_image_features):
        if 'compress_token' not in self.adapter_name:
            return batch_image_features
        compressed_features = []
        for image_features in batch_image_features:
            # Leave an image untouched when it already has fewer tokens than the target.
            if image_features.shape[1] < self.token_compressor.num_compressed_token:
                compressed_features.append(image_features)
            else:
                compressed_features.append(
                    self.token_compressor(image_features, compress_version=self.compress_version))
        return compressed_features

    def load_model(self):
        if self.is_loaded:
            return

        if getattr(self.config, 'adapter_module_path', None):
            checkpoint = torch.load(self.config.adapter_module_path, map_location="cpu")

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}

            def get_variable_frame_encoding_w(model_weights, load_weights):
                # Reuse as much of the pretrained frame position table as possible when
                # max_num_segments differs between the checkpoint and the current model.
                keyword = 'frame_position_encoding'
                model_len = model_weights.shape[0]
                load_weights_f_encoding = get_w(load_weights, keyword)

                load_len = load_weights_f_encoding['weight'].shape[0]
                if model_len <= load_len:
                    value = load_weights_f_encoding['weight'][:model_len]
                else:
                    value = model_weights.clone().cpu()
                    value[:load_len] = load_weights_f_encoding['weight']
                return value

            if 'qformer' in self.adapter_name and ('projector.bin' not in self.config.adapter_module_path):
                state_dict = checkpoint["model"]
                self.adapter.load_state_dict(get_w(state_dict, 'Qformer'))
                self.ln_vision.load_state_dict(get_w(state_dict, 'ln_vision'))
                self.load_state_dict({'query_tokens': state_dict['query_tokens']}, strict=False)
                if getattr(self.config, 'pretrain_mm_mlp_adapter', None):
                    mm_projector_weights = torch.load(self.config.pretrain_mm_mlp_adapter, map_location='cpu')
                    frame_encoding_weight = get_variable_frame_encoding_w(
                        self.frame_position_encoding.weight, mm_projector_weights)
                    self.frame_position_encoding.load_state_dict({'weight': frame_encoding_weight})
            else:
                frame_encoding_weight = get_variable_frame_encoding_w(self.frame_position_encoding.weight, checkpoint)
                for k in checkpoint.keys():
                    if 'frame_position_encoding' in k:
                        checkpoint[k] = frame_encoding_weight

                self.load_state_dict(get_w(checkpoint, 'adapter_module'))

            # Mark the module as loaded so repeated calls are no-ops.
            self.is_loaded = True

    def freeze_adapter_module(self, freeze_flag):
        for name, p in self.named_parameters():
            p.requires_grad = not freeze_flag

        # The resampler's positional embedding stays frozen regardless of the flag.
        if 'naive_resampler' in self.adapter_name:
            for name, p in self.named_parameters():
                if 'pos_embed' in name:
                    p.requires_grad = False
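
# End-to-end sketch (illustrative only; the config object is a stand-in for the
# model arguments, limited to the attributes this module actually reads, and the
# width/size values are placeholders):
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(
#         adapter_module_name='qformer_compress_token_144',
#         num_query_token=32,
#         hidden_size=4096,           # width of the features fed to the token compressor
#         max_num_segments=64,        # size of the frame position table
#         adapter_module_path=None,   # optional pretrained adapter checkpoint
#     )
#     adapter = build_adapter_module(cfg, vision_width=1408)
#     adapter.load_model()            # no-op here because adapter_module_path is unset
#     adapter.freeze_adapter_module(freeze_flag=False)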