import os
import re
import math
import torch
import torch.nn as nn
from .clip_encoder import CLIPVisionTower
from .eva_clip_encoder import EvaClipVisionTower
from .siglip_encoder import SiglipVisionTower
from .google_siglip_encoder import GoogleSiglipVisionTower
from llava.model.utils import LayerNorm
from .qformer import BertConfig, BertLMHeadModel
from .resampler import Resampler, TokenCompressor
from torch.nn.init import trunc_normal_


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    # is_absolute_path_exists = os.path.exists(vision_tower)
    if vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        vision_tower = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif vision_tower.startswith("eva"):
        vision_tower = EvaClipVisionTower(vision_tower, args=vision_tower_cfg)
    elif vision_tower.startswith("google/siglip"):
        vision_tower = GoogleSiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    elif 'HuggingFaceM4/siglip' in vision_tower:
        vision_tower = SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    else:
        raise ValueError(f'Unknown vision tower: {vision_tower}')
    return vision_tower
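

# Illustrative usage sketch (the config object and model name below are
# assumptions, shown only to demonstrate how the prefix checks above dispatch):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(mm_vision_tower="openai/clip-vit-large-patch14-336")
#   tower = build_vision_tower(cfg)  # -> CLIPVisionTower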


def build_Qformer(num_query_token, vision_width, extra_num_query_token=64, cross_attention_freq=2):
    ln_vision = LayerNorm(vision_width)
    encoder_config = BertConfig.from_pretrained("./model/bert-base-uncased")
    encoder_config.encoder_width = vision_width
    # insert a cross-attention layer every `cross_attention_freq` blocks (every other block by default)
    encoder_config.add_cross_attention = True
    encoder_config.cross_attention_freq = cross_attention_freq
    encoder_config.query_length = num_query_token
    Qformer = BertLMHeadModel(config=encoder_config)
    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
    # drop the LM head and the text-embedding/FFN branches that are unused for visual feature aggregation
    Qformer.cls = None
    Qformer.bert.embeddings.word_embeddings = None
    Qformer.bert.embeddings.position_embeddings = None
    for layer in Qformer.bert.encoder.layer:
        layer.output = None
        layer.intermediate = None
    return Qformer, ln_vision, query_tokens
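

# Minimal sketch of how the returned pieces fit together (shapes are assumptions
# for illustration; the hidden size comes from the BERT config loaded above):
#
#   Qformer, ln_vision, query_tokens = build_Qformer(num_query_token=32, vision_width=1024)
#   # query_tokens: [1, 32, hidden_size]; AdapterModule.forward expands it per batch
#   # and cross-attends over ln_vision(image_features) of shape [B, num_patches, 1024].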


# TODO: remove the vision_width here
def build_adapter_module(cfg, vision_width):
    return AdapterModule(cfg, vision_width)


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class AdapterModule(nn.Module):
    def __init__(self, config, vision_width):
        super().__init__()
        self.adapter_name = config.adapter_module_name
        self.config = config
        self.output_dim = vision_width
        if 'perceiver' in self.adapter_name:
            from flash_perceiver import Perceiver
            self.adapter = Perceiver(
                input_dim=vision_width,
                depth=6,
                output_dim=vision_width,
                num_latents=self.config.num_query_token,
                latent_dim=1024,
                cross_heads=1,
                cross_head_dim=128,
                cross_rotary_emb_dim=0,
                cross_attn_dropout=0.0,
                latent_heads=8,
                latent_head_dim=128,
                latent_rotary_emb_dim=0,
                latent_attn_dropout=0.0,
                weight_tie_layers=False,
                gated_mlp=True,
                self_per_cross_attn=1,
                num_zero_tokens=None,
                use_flash_attn=True,
            )
        elif 'naive_resampler' in self.adapter_name:
            assert math.sqrt(self.config.num_query_token) ** 2 == self.config.num_query_token, \
                'num_query_token must be a perfect square'
            self.adapter = Resampler(
                grid_size=int(math.sqrt(self.config.num_query_token)),
                embed_dim=vision_width,
                num_heads=8,
            )
        elif 'qformer' in self.adapter_name:
            Qformer, ln_vision, query_tokens = build_Qformer(
                self.config.num_query_token, vision_width)
            self.adapter = Qformer
            self.ln_vision = ln_vision
            self.query_tokens = query_tokens
            self.output_dim = Qformer.config.hidden_size
        elif 'none' in self.adapter_name:
            self.adapter = IdentityMap()
        self.is_loaded = False
        if 'compress_token' in self.adapter_name:
            match = re.search(r'\d+$', self.adapter_name)
            self.token_compressor = TokenCompressor(
                num_compressed_token=int(match.group()),
                embed_dim=self.config.hidden_size,
                num_heads=8,
            )
            if 'v1' in self.adapter_name:
                self.compress_version = 'v1'
            else:
                self.compress_version = 'v0'
        # self.ln_vision = LayerNorm(self.config.vision_in_dim)
        self.frame_position_encoding = nn.Embedding(
            config.max_num_segments,
            self.output_dim,
        )
        self.adapter.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, image_features, frame_ids):
        if 'perceiver' in self.adapter_name:
            adapted_image_features = self.adapter(image_features, return_embeddings=True)
        elif 'naive_resampler' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)
        elif 'qformer' in self.adapter_name:
            image_features = self.ln_vision(image_features)
            query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
            attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device)
            adapted_image_features = self.adapter.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_features,
                encoder_attention_mask=attn_mask,
                return_dict=True
            ).last_hidden_state
        elif 'none' in self.adapter_name:
            adapted_image_features = self.adapter(image_features)
        frame_embeddings = self.frame_position_encoding(frame_ids).unsqueeze(-2)
        adapted_image_features += frame_embeddings
        return adapted_image_features
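
    # Illustrative call (sizes are assumptions, not fixed by this module): the
    # channel dim of `image_features` should match the vision tower width, and
    # the output channel dim is `self.output_dim`.
    #
    #   feats = torch.randn(8, 576, vision_width)   # 8 frames, 576 patch tokens each
    #   frame_ids = torch.arange(8)                  # one temporal position id per frame
    #   out = adapter(feats, frame_ids)              # [8, num_query_token, output_dim]
    #                                                # ('none' keeps [8, 576, output_dim])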

    # TODO: ad-hoc function, rewrite it in the future
    def compress_token_per_img(self, batch_image_features):
        if 'compress_token' not in self.adapter_name:
            return batch_image_features
        compressed_features = []
        for image_features in batch_image_features:  # image_features: [num_frames, tokens, C]
            # Handle non-image cases, where the number of patch tokens can be smaller
            # than num_compressed_token; such entries are passed through unchanged.
            if image_features.shape[1] < self.token_compressor.num_compressed_token:
                compressed_features.append(image_features)
            else:
                compressed_features.append(self.token_compressor(image_features, compress_version=self.compress_version))
        return compressed_features
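
    # Usage sketch (hypothetical sizes; the feature dim is expected to match
    # `config.hidden_size`, the dim the TokenCompressor above was built with):
    #
    #   batch = [torch.randn(16, 576, hidden_size)]   # one video sample, 16 frames
    #   out = adapter.compress_token_per_img(batch)    # each entry -> [16, num_compressed_token, hidden_size]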

    def load_model(self):
        if self.is_loaded:
            return
        if getattr(self.config, 'adapter_module_path', None):
            checkpoint = torch.load(self.config.adapter_module_path, map_location="cpu")

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}

            def get_variable_frame_encoding_w(model_weights, load_weights):
                keyword = 'frame_position_encoding'
                model_len = model_weights.shape[0]
                load_weights_f_encoding = get_w(load_weights, keyword)
                load_len = load_weights_f_encoding['weight'].shape[0]
                if model_len <= load_len:
                    value = load_weights_f_encoding['weight'][:model_len]
                else:
                    value = model_weights.clone().cpu()
                    value[:load_len] = load_weights_f_encoding['weight']
                return value

            if 'qformer' in self.adapter_name and ('projector.bin' not in self.config.adapter_module_path):
                state_dict = checkpoint["model"]
                self.adapter.load_state_dict(get_w(state_dict, 'Qformer'))
                self.ln_vision.load_state_dict(get_w(state_dict, 'ln_vision'))
                self.load_state_dict({'query_tokens': state_dict['query_tokens']}, strict=False)
                if getattr(self.config, 'pretrain_mm_mlp_adapter', None):
                    mm_projector_weights = torch.load(self.config.pretrain_mm_mlp_adapter, map_location='cpu')
                    frame_encoding_weight = get_variable_frame_encoding_w(self.frame_position_encoding.weight, mm_projector_weights)
                    self.frame_position_encoding.load_state_dict({'weight': frame_encoding_weight})
            else:
                frame_encoding_weight = get_variable_frame_encoding_w(self.frame_position_encoding.weight, checkpoint)
                for k in checkpoint.keys():
                    if 'frame_position_encoding' in k:
                        checkpoint[k] = frame_encoding_weight
                self.load_state_dict(get_w(checkpoint, 'adapter_module'))
        else:
            # no pretrained weights; keep the default initialization
            return

    def freeze_adapter_module(self, freeze_flag):
        if freeze_flag:
            for name, p in self.named_parameters():
                p.requires_grad = False
        else:
            for name, p in self.named_parameters():
                p.requires_grad = True
            # keep the resampler's fixed positional embedding frozen even when training the adapter
            if 'naive_resampler' in self.adapter_name:
                for name, p in self.named_parameters():
                    if 'pos_embed' in name:
                        p.requires_grad = False
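

# End-to-end sketch (hedged: the `cfg` fields below are assumptions mirroring the
# attributes this module reads, not a documented config schema):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(adapter_module_name='qformer', num_query_token=32,
#                         max_num_segments=64, adapter_module_path=None)
#   adapter = build_adapter_module(cfg, vision_width=1024)
#   adapter.load_model()                  # falls through: no adapter_module_path set
#   adapter.freeze_adapter_module(True)   # freeze all adapter parameters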