from .modeling_bert import BertModel

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiModalBert(BertModel):
    def __init__(self, config, embed_dim, pwam_idx=[3, 6, 9, 12], num_heads_fusion=[1, 1, 1, 1], fusion_drop=0.0):
        super().__init__(config)
        self.pwam_idx = pwam_idx
        self.num_heads_fusion = num_heads_fusion
        self.fusion_drop = fusion_drop

        # Visual channel count expected at each fusion point; doubles per stage (embed_dim * 2**i).
        pwam_dims = [embed_dim * 2 ** i for i in range(len(pwam_idx))]

        self.pwams = nn.ModuleList()
        self.res_gates = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(len(pwam_idx)):
            dim = pwam_dims[i]
            fusion = PWAM(768,  # dim: channels handled by vis_project -- the language hidden size (BERT-base)
                          dim,  # v_in: visual channels at this stage
                          768,  # l_in
                          768,  # key
                          768,  # value
                          num_heads=num_heads_fusion[i],
                          dropout=fusion_drop)
            self.pwams.append(fusion)

            # Residual gate: a small MLP whose linear layers are zero-initialized below,
            # so the fused residual starts switched off and its strength is learned.
            res_gate = nn.Sequential(
                nn.Linear(768, 768, bias=False),
                nn.ReLU(),
                nn.Linear(768, 768, bias=False),
                nn.Tanh()
            )
            nn.init.zeros_(res_gate[0].weight)
            nn.init.zeros_(res_gate[2].weight)
            self.res_gates.append(res_gate)

            self.norms.append(nn.LayerNorm(768))

    def forward_stem(self, input_ids, attention_mask):
        # Embed the tokens and build the additive attention mask once;
        # the per-stage methods below reuse both.
        input_shape = input_ids.size()
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)  # single-segment input
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, input_ids.device)
        embedding_output = self.embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids
        )
        return embedding_output, extended_attention_mask
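
    # Intended usage (a sketch, based on how the pieces fit together in this file; the
    # actual driver lives outside it and interleaves a visual backbone). The names
    # visual_feats_* and pwam_mask below are hypothetical:
    #   l, l_mask = model.forward_stem(input_ids, attention_mask)
    #   l = model.forward_stage1(l, l_mask)
    #   l_fused, l = model.forward_pwam1(visual_feats_1, l, pwam_mask)
    #   l = model.forward_stage2(l, l_mask)
    #   l_fused, l = model.forward_pwam2(visual_feats_2, l, pwam_mask)
    #   ... likewise for stages 3 and 4.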

    def forward_stage1(self, hidden_states, attention_mask):
        # Run BERT encoder layers [0, pwam_idx[0]) on the language stream.
        for i in range(0, self.pwam_idx[0]):
            layer_module = self.encoder.layer[i]
            layer_outputs = layer_module(hidden_states, attention_mask)
            hidden_states = layer_outputs[0]
        return hidden_states

    def forward_stage2(self, hidden_states, attention_mask):
        # Run BERT encoder layers [pwam_idx[0], pwam_idx[1]).
        for i in range(self.pwam_idx[0], self.pwam_idx[1]):
            layer_module = self.encoder.layer[i]
            layer_outputs = layer_module(hidden_states, attention_mask)
            hidden_states = layer_outputs[0]
        return hidden_states

    def forward_stage3(self, hidden_states, attention_mask):
        # Run BERT encoder layers [pwam_idx[1], pwam_idx[2]).
        for i in range(self.pwam_idx[1], self.pwam_idx[2]):
            layer_module = self.encoder.layer[i]
            layer_outputs = layer_module(hidden_states, attention_mask)
            hidden_states = layer_outputs[0]
        return hidden_states

    def forward_stage4(self, hidden_states, attention_mask):
        # Run BERT encoder layers [pwam_idx[2], pwam_idx[3]).
        for i in range(self.pwam_idx[2], self.pwam_idx[3]):
            layer_module = self.encoder.layer[i]
            layer_outputs = layer_module(hidden_states, attention_mask)
            hidden_states = layer_outputs[0]
        return hidden_states
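
    # The forward_pwam* methods fuse visual features x into the language stream l
    # with a gated residual update:
    #     r = PWAM_i(x, l, l_mask)              # words attend over visual positions
    #     l <- l + tanh(W2(relu(W1(r)))) * r    # res_gate is zero-initialized, so this starts as identity
    # They return (LayerNorm(r), l): the normalized fused features and the updated language stream.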
    def forward_pwam1(self, x, l, l_mask):
        l_residual = self.pwams[0](x, l, l_mask)
        l = l + (self.res_gates[0](l_residual) * l_residual)
        return self.norms[0](l_residual), l

    def forward_pwam2(self, x, l, l_mask):
        l_residual = self.pwams[1](x, l, l_mask)
        l = l + (self.res_gates[1](l_residual) * l_residual)
        return self.norms[1](l_residual), l

    def forward_pwam3(self, x, l, l_mask):
        l_residual = self.pwams[2](x, l, l_mask)
        l = l + (self.res_gates[2](l_residual) * l_residual)
        return self.norms[2](l_residual), l

    def forward_pwam4(self, x, l, l_mask):
        l_residual = self.pwams[3](x, l, l_mask)
        l = l + (self.res_gates[3](l_residual) * l_residual)
        return self.norms[3](l_residual), l


class PWAM(nn.Module):
    def __init__(self, dim, v_in_channels, l_in_channels, key_channels, value_channels, num_heads=0, dropout=0.0):
        super(PWAM, self).__init__()
        # Projects the language features (B, N_l, dim).
        # (An earlier variant projected the visual features with a Conv1d here instead.)
        self.vis_project = nn.Sequential(nn.Linear(dim, dim),  # the init function sets bias to 0 if bias is True
                                         nn.GELU(),
                                         nn.Dropout(dropout)
                                         )

        self.image_lang_att = SpatialImageLanguageAttention(v_in_channels,   # v_in
                                                            l_in_channels,   # l_in
                                                            key_channels,    # key
                                                            value_channels,  # value
                                                            out_channels=value_channels,  # out
                                                            num_heads=num_heads)

        self.project_mm = nn.Sequential(nn.Conv1d(value_channels, value_channels, 1, 1),
                                        nn.GELU(),
                                        nn.Dropout(dropout)
                                        )
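
    # Forward data flow (shapes as used in this file):
    #   x:      (B, H*W, v_in_channels)  visual features from the backbone
    #   l:      (B, N_l, dim)            language features from BERT
    #   l_mask: broadcastable to (B, 1, N_l) after the squeeze inside
    #           SpatialImageLanguageAttention (the exact layout is up to the caller)
    # Each word is projected (vis_project), enriched with attended visual context
    # (image_lang_att), the two are multiplied elementwise, and the product is
    # projected again (project_mm). Output: (B, N_l, value_channels).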
    def forward(self, x, l, l_mask):
        # x: (B, H*W, v_in_channels), l: (B, N_l, dim)
        vis = self.vis_project(l)                   # (B, N_l, dim)
        lang = self.image_lang_att(x, l, l_mask)    # (B, N_l, value_channels)
        lang = lang.permute(0, 2, 1)                # (B, value_channels, N_l)

        mm = torch.mul(vis.permute(0, 2, 1), lang)  # (B, value_channels, N_l); requires dim == value_channels
        mm = self.project_mm(mm)                    # (B, value_channels, N_l)
        mm = mm.permute(0, 2, 1)                    # (B, N_l, value_channels)
        return mm


class SpatialImageLanguageAttention(nn.Module):
    def __init__(self, v_in_channels, l_in_channels, key_channels, value_channels, out_channels=None, num_heads=1):
        super(SpatialImageLanguageAttention, self).__init__()
        # x shape:      (B, H*W, v_in_channels)
        # l shape:      (B, N_l, l_in_channels)
        # l_mask shape: broadcastable to (B, 1, N_l) after squeeze(1) in forward()
        self.v_in_channels = v_in_channels
        self.l_in_channels = l_in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        self.num_heads = num_heads
        if out_channels is None:
            self.out_channels = self.value_channels

        # Queries: language features (B, l_in_channels, N_l).
        # Avoid any form of spatial normalization here because a sentence contains many padding 0s.
        self.f_query = nn.Sequential(
            nn.Conv1d(self.l_in_channels, self.key_channels, kernel_size=1, stride=1),
        )

        # Keys: visual features (B, v_in_channels, H*W).
        self.f_key = nn.Sequential(
            nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.key_channels),
        )

        # Values: visual features (B, v_in_channels, H*W).
        # (An earlier variant took the values from the language features instead.)
        self.f_value = nn.Sequential(
            nn.Conv1d(self.v_in_channels, self.value_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.value_channels),
        )

        # Out projection.
        self.W = nn.Sequential(
            nn.Conv1d(self.value_channels, self.out_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.out_channels),
        )
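
    # Cross-attention with language queries over visual keys/values: each word
    # position gathers a weighted summary of the H*W visual positions. An earlier
    # variant of this module attended the other way round (visual queries over
    # word keys/values).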
    def forward(self, x, l, l_mask):
        # x:      (B, H*W, v_in_channels)
        # l:      (B, N_l, l_in_channels)
        # l_mask: broadcastable to (B, 1, N_l) after the squeeze below
        l_mask = l_mask.squeeze(1)
        B, HW = x.size(0), x.size(1)
        x = x.permute(0, 2, 1)  # (B, v_in_channels, H*W)
        l = l.permute(0, 2, 1)  # (B, l_in_channels, N_l)

        query = self.f_query(l)         # (B, key_channels, N_l)
        query = query * l_mask          # zero out padded word positions
        query = query.permute(0, 2, 1)  # (B, N_l, key_channels)
        key = self.f_key(x)             # (B, key_channels, H*W)
        value = self.f_value(x)         # (B, value_channels, H*W)
        n_l = query.size(1)

        key = key.reshape(B, self.num_heads, self.key_channels // self.num_heads, HW)
        value = value.reshape(B, self.num_heads, self.value_channels // self.num_heads, HW)
        query = query.reshape(B, n_l, self.num_heads, self.key_channels // self.num_heads).permute(0, 2, 1, 3)
        # query: (B, num_heads, N_l, key_channels // num_heads)
        # key:   (B, num_heads, key_channels // num_heads, H*W)
        # value: (B, num_heads, value_channels // num_heads, H*W)

        l_mask = l_mask.unsqueeze(-1)                   # broadcastable against (B, num_heads, N_l, H*W)
        sim_map = torch.matmul(query, key)              # (B, num_heads, N_l, H*W)
        sim_map = (self.key_channels ** -.5) * sim_map  # scaled dot product
        sim_map = sim_map + (1e4 * l_mask - 1e4)        # assign a very small number to padded word positions
        sim_map = F.softmax(sim_map, dim=-1)            # attention over the H*W visual positions

        out = torch.matmul(sim_map, value.permute(0, 1, 3, 2))  # (B, num_heads, N_l, value_channels // num_heads)
        out = out.permute(0, 2, 1, 3).contiguous().reshape(B, n_l, self.value_channels)  # (B, N_l, value_channels)
        out = out.permute(0, 2, 1)  # (B, value_channels, N_l)
        out = self.W(out)           # (B, out_channels, N_l)
        out = out.permute(0, 2, 1)  # (B, N_l, out_channels)
        return out
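

# Minimal shape-check sketch (not part of the original module). It assumes a toy
# visual feature map with 96 channels over a 7x7 grid and a 20-token sentence; the
# mask layout (B, 1, 1, N_l) is an assumption inferred from how l_mask is squeezed
# and broadcast inside SpatialImageLanguageAttention. Because of the relative import
# at the top of this file, run it as a module inside its package (python -m ...)
# for the check to execute.
if __name__ == "__main__":
    B, HW, N_l = 2, 7 * 7, 20
    pwam = PWAM(dim=768,           # language hidden size
                v_in_channels=96,  # assumed visual channel count for this toy check
                l_in_channels=768,
                key_channels=768,
                value_channels=768,
                num_heads=1)
    x = torch.randn(B, HW, 96)         # visual features (B, H*W, v_in_channels)
    l = torch.randn(B, N_l, 768)       # language features (B, N_l, 768)
    l_mask = torch.ones(B, 1, 1, N_l)  # 1 for real tokens, 0 for padding
    out = pwam(x, l, l_mask)
    print(out.shape)                   # expected: torch.Size([2, 20, 768])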