# Copyright 2024 Xi Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import torch
import torch.nn as nn
import torchvision.ops as ops


class TAC(nn.Module):
    """Projector that fuses current and prior image features via self- and cross-attention."""

    def __init__(self, config):
        super(TAC, self).__init__()
        self.mm_hidden_size = config.mm_hidden_size
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.dropout = 0.1
        self.layers_number = 12  # RAD-DINO hidden layers

        # LFE: compresses the stack of hidden-layer feature maps
        # (layers_number channels) down to a single channel.
        self.LFE = nn.Sequential(
            ops.SqueezeExcitation(self.layers_number, self.layers_number // 2, activation=nn.GELU),
            nn.Conv2d(self.layers_number, self.layers_number // 2, kernel_size=1, bias=False),
            ops.SqueezeExcitation(self.layers_number // 2, self.layers_number // 4, activation=nn.GELU),
            nn.Conv2d(self.layers_number // 2, self.layers_number // 4, kernel_size=1, bias=False),
            ops.SqueezeExcitation(self.layers_number // 4, 1, activation=nn.GELU),
            nn.Conv2d(self.layers_number // 4, 1, kernel_size=1, bias=False),
        )
        self.LFE_prior_bias = nn.Parameter(
            torch.tensor(0.0, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        )
        self.LFE_cos = nn.CosineSimilarity(dim=-1, eps=1e-6)

        # self-attention
        self.cur_self_attention = nn.MultiheadAttention(
            embed_dim=self.mm_hidden_size,
            num_heads=self.num_attention_heads,
            batch_first=True,
            add_bias_kv=True,
        )
        self.prior_self_attention = nn.MultiheadAttention(
            embed_dim=self.mm_hidden_size,
            num_heads=self.num_attention_heads,
            batch_first=True,
            add_bias_kv=True,
        )
        self.cros_attention = nn.MultiheadAttention(
            embed_dim=self.mm_hidden_size,
            num_heads=self.num_attention_heads,
            batch_first=True,
            add_bias_kv=True,
        )

        self.norm1 = nn.LayerNorm(self.mm_hidden_size)
        self.norm2 = nn.LayerNorm(self.mm_hidden_size)
        self.norm3 = nn.LayerNorm(self.mm_hidden_size)
        self.norm4 = nn.LayerNorm(self.mm_hidden_size)

        self.mlp_attn = nn.Sequential(
            nn.Linear(self.mm_hidden_size, self.mm_hidden_size),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.mm_hidden_size, self.mm_hidden_size),
            nn.Dropout(self.dropout),
        )
        self.mlp_final = nn.Sequential(
            nn.Linear(self.mm_hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

        self.dropout1 = nn.Dropout(self.dropout)
        self.dropout2 = nn.Dropout(self.dropout)
        self.dropout3 = nn.Dropout(self.dropout)

    def calculate_cosine_similarity(self, tensor1, tensor2):
        assert tensor1.shape == tensor2.shape, "The shapes of the two tensors must be the same"
        tensor1_flat = tensor1.view(tensor1.size(0), -1)
        tensor2_flat = tensor2.view(tensor2.size(0), -1)
        tensor1_flat_normalized = tensor1_flat / tensor1_flat.norm(dim=-1, keepdim=True)
        tensor2_flat_normalized = tensor2_flat / tensor2_flat.norm(dim=-1, keepdim=True)
        cosine_similarities = self.LFE_cos(tensor1_flat_normalized, tensor2_flat_normalized)
        # Map from [-1, 1] to [0, 1], sharpen, and reshape for broadcasting.
        cosine_similarities_normalized = ((cosine_similarities + 1) / 2).pow(8)
        cosine_similarities_normalized = cosine_similarities_normalized.view(-1, 1, 1)
        return cosine_similarities_normalized

    # self-attention block
    def cur_self_att_block(self, x):
        x = self.cur_self_attention(x, x, x)[0]
        return self.dropout1(x)

    # self-attention block
    def prior_self_att_block(self, x):
        x = self.prior_self_attention(x, x, x)[0]
        return self.dropout2(x)

    # cross attention block
    def cros_att_block(self, x, y):
        x = self.cros_attention(x, y, y)[0]
        return self.dropout3(x)

    # TFM
    def TFM(self, cur_features, prev_features):
        cur_features_temp = cur_features
        prev_features_temp = prev_features
        # Weight the prior features by their similarity to the current features.
        cos = self.calculate_cosine_similarity(cur_features_temp, prev_features_temp)
        prev_weight = cos * self.LFE_prior_bias
        prev_features_temp = prev_features_temp + prev_weight
        cur_features = self.norm1(cur_features_temp + self.cur_self_att_block(cur_features_temp))
        prev_features = self.norm2(prev_features_temp + self.prior_self_att_block(prev_features_temp))
        combined_features = self.norm3(cur_features + self.cros_att_block(cur_features, prev_features))
        output = self.norm4(cur_features_temp + self.mlp_attn(combined_features))
        output = self.mlp_final(output)
        return output

    def forward(self, image_features, *args, **kwargs):
        cur_features, prev_features = image_features
        cur_features = self.LFE(cur_features).squeeze(1)
        prev_features = self.LFE(prev_features).squeeze(1)
        output = self.TFM(cur_features, prev_features)
        return output

    @property
    def config(self):
        return {"mm_projector_type": 'TAC'}


class Projector(nn.Module):
    def __init__(self, base_projector):
        super().__init__()
        self.projector = base_projector

    def forward(self, image_features, *args, **kwargs):
        temp_features = image_features[0].squeeze(1)
        return self.projector(temp_features)


def build_vision_projector(config, delay_load=False, *args, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        linear_layer = nn.Linear(config.mm_hidden_size, config.hidden_size)
        return Projector(linear_layer)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return Projector(nn.Sequential(*modules))

    if projector_type == 'TAC':
        return TAC(config)

    raise ValueError(f'Unknown projector type: {projector_type}')
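

# --- Usage sketch (illustrative only, not part of the original module). ---
# The config values and tensor shapes below are assumptions chosen to satisfy
# the shapes this module expects: mm_hidden_size must match the vision-token
# width and be divisible by num_attention_heads, and each input is a stack of
# 12 hidden-layer feature maps of shape (batch, 12, num_patches, mm_hidden_size).
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        mm_projector_type="TAC",
        mm_hidden_size=768,      # assumed vision-token width
        hidden_size=4096,        # assumed LLM hidden size
        num_attention_heads=12,  # must divide mm_hidden_size
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    projector = build_vision_projector(cfg).to(device)

    # Dummy (current, prior) feature stacks: (batch, layers, patches, width).
    cur = torch.randn(2, 12, 196, cfg.mm_hidden_size, device=device)
    prev = torch.randn(2, 12, 196, cfg.mm_hidden_size, device=device)
    out = projector((cur, prev))
    print(out.shape)  # expected: torch.Size([2, 196, 4096])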