import math

import numpy as np
import torch
import torch.nn as nn
class LQFormerAttention(nn.Module):
    """Multi-head cross-attention whose queries and keys live in a down-projected
    space of width down_dim, while the values keep the full embedding width embed_dim."""

    def __init__(self, embed_dim, num_heads, down_dim, up_dim):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.head_dim = embed_dim // num_heads
        self.down_head_dim = down_dim // num_heads

        # Queries and keys are projected in the down-projected space; values keep
        # the full embedding dimension, so the attention output stays at embed_dim.
        self.q_proj = nn.Linear(self.down_dim, self.down_dim, bias=True)
        self.k_proj = nn.Linear(self.down_dim, self.down_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(self, query, key, value, attention_mask=None):
        bsz, q_len, _ = query.size()
        k_len = key.size(1)
        v_len = value.size(1)

        # Project and split into heads: (bsz, num_heads, seq_len, head_dim).
        query = self.q_proj(query).view(bsz, q_len, self.num_heads, self.down_head_dim).transpose(1, 2)
        key = self.k_proj(key).view(bsz, k_len, self.num_heads, self.down_head_dim).transpose(1, 2)
        value = self.v_proj(value).view(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention, computed in float32 for numerical stability.
        attn_weights = torch.matmul(
            query.to(torch.float32), key.to(torch.float32).transpose(2, 3)
        ) / math.sqrt(self.down_head_dim)

        if attention_mask is not None:
            # Padding positions (mask == 0) are pushed to a large negative value so
            # softmax assigns them ~zero weight; the +1 that unmasked positions pick
            # up is a uniform shift that softmax cancels.
            attention_mask = attention_mask.masked_fill(attention_mask == 0, -1e4)
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value.dtype)
        attn_output = torch.matmul(attn_weights, value)

        # Merge heads back: (bsz, q_len, embed_dim).
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return attn_output, attn_weights
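
# --- Usage sketch (added for illustration; not part of the original model code) ---
# A minimal shape check for LQFormerAttention under assumed sizes: down-projected
# queries/keys of width 576, full-width values of 1152, and 8 heads. The tensor
# sizes are illustrative placeholders, and this helper is never called by the model.
def _lqformer_attention_shape_sketch():
    attn = LQFormerAttention(embed_dim=1152, num_heads=8, down_dim=576, up_dim=2560)
    q = torch.randn(2, 32, 576)    # (batch, num_query_tokens, down_dim)
    k = torch.randn(2, 729, 576)   # (batch, num_key_tokens, down_dim)
    v = torch.randn(2, 729, 1152)  # (batch, num_key_tokens, embed_dim)
    out, weights = attn(q, k, v)
    assert out.shape == (2, 32, 1152)        # values keep the full embed_dim width
    assert weights.shape == (2, 8, 32, 729)  # per-head attention maps
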
class LQFormerLayer(nn.Module):
    """One LQ-Former block: the learnable query tokens first attend to the text
    tokens, then the text-conditioned queries attend to the image tokens."""

    def __init__(self, d_model, mm_model, n_heads, down_dim, up_dim):
        super(LQFormerLayer, self).__init__()
        # Text-to-query attention runs entirely in the down-projected space.
        self.t2q_attn = LQFormerAttention(embed_dim=down_dim, num_heads=n_heads, down_dim=down_dim, up_dim=up_dim)
        # Image-to-query attention: queries/keys are down-projected, values are full-width image tokens.
        self.i2q_attn = LQFormerAttention(embed_dim=d_model, num_heads=n_heads, down_dim=down_dim, up_dim=up_dim)
        self.ln_text = nn.LayerNorm(down_dim)
        self.ln_q = nn.LayerNorm(down_dim)
        self.ln_kv = nn.LayerNorm(down_dim)
        self.n_heads = n_heads

    def forward(self, learnable_tokens, image_tokens, image_tokens_down, text_tokens, text_mask=None):
        # The learnable and text tokens arrive already down-projected; the residual
        # connection around this block is applied by the parent LQFormer module.
        learnable_tokens = self.ln_q(learnable_tokens)
        text_tokens = self.ln_text(text_tokens)

        if text_mask is not None:
            # Broadcast the text padding mask to (batch_size, n_heads, num_queries, text_len).
            attention_mask = text_mask.unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, text_len)
            attention_mask = attention_mask.repeat(1, self.n_heads, learnable_tokens.size(1), 1)
        else:
            attention_mask = None

        # Cross-attention: learnable tokens query the text tokens.
        attn_output, _ = self.t2q_attn(query=learnable_tokens, key=text_tokens, value=text_tokens, attention_mask=attention_mask)

        # Cross-attention: text-conditioned learnable tokens query the image tokens.
        image_tokens_down = self.ln_kv(image_tokens_down)
        attn_output, attention_map = self.i2q_attn(query=attn_output, key=image_tokens_down, value=image_tokens, attention_mask=None)

        # Average the image attention map over heads: (batch_size, num_queries, num_image_tokens).
        attention_map = torch.mean(attention_map, dim=1)
        return attn_output, attention_map
class LQFormer(nn.Module):
    """LQ-Former: down-projects the learnable, text, and image tokens, runs them
    through LQFormerLayer blocks, and projects the queried features back up to the
    input hidden size (config.hidden_size) with a residual connection."""

    def __init__(self, config, num_layers=1):
        super(LQFormer, self).__init__()
        self.mm_model = config.hidden_size  # width of the learnable and text tokens
        self.d_model = 1152                 # width of the image tokens
        self.down_dim = 576                 # shared low-dimensional width for queries/keys
        # Learnable and text tokens share one down-projection; image tokens get their own.
        self.down_projector_learnable_text = nn.Linear(self.mm_model, self.down_dim, bias=True)
        self.down_projector_image = nn.Linear(self.d_model, self.down_dim, bias=True)
        self.layers = nn.ModuleList(
            [LQFormerLayer(mm_model=self.mm_model, d_model=1152, n_heads=config.num_attention_heads,
                           down_dim=576, up_dim=2560) for _ in range(num_layers)]
        )
        self.up_projector = nn.Linear(self.d_model, self.mm_model)

    def forward(self, learnable_tokens, image_tokens, text_tokens, text_mask=None):
        # Down-project everything into the shared low-dimensional attention space.
        learnable_tokens_down = self.down_projector_learnable_text(learnable_tokens)
        text_tokens_down = self.down_projector_learnable_text(text_tokens)
        image_tokens_down = self.down_projector_image(image_tokens)

        # Pass through the layers. Each layer reads the initial down-projected
        # learnable tokens, which only matters when num_layers > 1.
        for layer in self.layers:
            residual = learnable_tokens
            learnable_tokens, attention_map = layer(learnable_tokens_down, image_tokens, image_tokens_down, text_tokens_down, text_mask)
            # Layer output has width d_model; project back up to the input hidden size.
            learnable_tokens = self.up_projector(learnable_tokens)
            learnable_tokens = residual + learnable_tokens
        return learnable_tokens
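
# --- Smoke test (added for illustration; not part of the original model code) ---
# A minimal end-to-end sketch of LQFormer. The config below is a stand-in: the real
# config is only assumed to expose `hidden_size` and `num_attention_heads`, and the
# sizes used here (2560 hidden size, 32 heads, 729 image tokens) are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=2560, num_attention_heads=32)
    model = LQFormer(config, num_layers=1)

    learnable = torch.randn(2, 64, 2560)  # (batch, num_learnable_tokens, hidden_size)
    image = torch.randn(2, 729, 1152)     # (batch, num_image_tokens, d_model)
    text = torch.randn(2, 40, 2560)       # (batch, text_len, hidden_size)
    text_mask = torch.ones(2, 40)         # 1 = real token, 0 = padding

    out = model(learnable, image, text, text_mask)
    assert out.shape == (2, 64, 2560)     # residual + up-projected learnable tokens
    print("LQFormer output shape:", tuple(out.shape))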