import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


class DotAttn(nn.Module):
    """ Dot-Attention """

    def forward(self, inp, h):
        # inp: (B, T, D) values, h: (B, D) query
        score = self.softmax(inp, h)
        # weighted sum over T -> (B, D), plus the (B, T, 1) attention weights
        return score.expand_as(inp).mul(inp).sum(1), score

    def softmax(self, inp, h):
        raw_score = inp.bmm(h.unsqueeze(2))
        score = F.softmax(raw_score, dim=1)
        return score


class ScaledDotAttn(nn.Module):
    """ Scaled Dot-Attention """

    def forward(self, inp, h):
        # inp: (B, T, D) values, h: (B, D) query
        score = self.softmax(inp, h)
        return score.expand_as(inp).mul(inp).sum(1), score

    def softmax(self, inp, h):
        # scale the raw scores by sqrt(d) before the softmax
        raw_score = inp.bmm(h.unsqueeze(2)) / np.sqrt(h.shape[-1])
        score = F.softmax(raw_score, dim=1)
        return score


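# Minimal usage sketch for the attention modules above (illustrative only;
# shapes are assumptions derived from the bmm call, not part of the original API):
#
#   attn = ScaledDotAttn()
#   values = torch.randn(2, 10, 64)    # (B, T, D)
#   query = torch.randn(2, 64)         # (B, D)
#   context, weights = attn(values, query)
#   # context: (2, 64), weights: (2, 10, 1)

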
class Fusion(nn.Module):
    """ Base Fusion Class """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim

    def tile_x2(self, x1, x2, x2_proj=None):
        if x2_proj:
            x2 = x2_proj(x2)

        # broadcast a batch-less (or single-batch) feature over the batch and spatial grid of x1
        x2 = x2.unsqueeze(-1).unsqueeze(-1)
        x2 = x2.repeat(x1.shape[0], 1, x1.shape[-2], x1.shape[-1])
        return x2

    def batch_tile_x2(self, x1, x2, x2_proj=None):
        if x2_proj:
            x2 = x2_proj(x2)

        # x2 already carries a batch dim; only tile over the spatial grid of x1
        x2 = x2.unsqueeze(-1).unsqueeze(-1)
        x2 = x2.repeat(1, 1, x1.shape[-2], x1.shape[-1])
        return x2

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        raise NotImplementedError()


class FusionAdd(Fusion):
    """ x1 + x2 """

    def __init__(self, input_dim=3):
        super(FusionAdd, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        return x1 + x2


class FusionMult(Fusion):
    """ x1 * x2 """

    def __init__(self, input_dim=3):
        super(FusionMult, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.batch_tile_x2(x1, x2, x2_proj)
        return x1 * x2


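# Illustrative tiling sketch for the element-wise fusions (hypothetical shapes,
# not part of the original API):
#
#   fusion = FusionMult(input_dim=1024)
#   x1 = torch.randn(4, 1024, 7, 7)    # image features (B, D, H, W)
#   x2 = torch.randn(1, 1024)          # pooled language feature (1, D)
#   out = fusion(x1, x2)               # x2 tiled to (1, 1024, 7, 7), broadcast multiply -> (4, 1024, 7, 7)

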
class FusionMax(Fusion):
    """ max(x1, x2) """

    def __init__(self, input_dim=3):
        super(FusionMax, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        return torch.max(x1, x2)


class FusionConcat(Fusion):
    """ [x1; x2] """

    def __init__(self, input_dim=3):
        super(FusionConcat, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        return torch.cat([x1, x2], dim=1)


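# Note (illustrative): with x1 of shape (B, D, H, W) and x2 tiled to match,
# FusionConcat returns (B, 2 * D, H, W); downstream layers must account for the
# doubled channel dimension, as FusionConv below does with input_dim * 2.

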
class FusionConv(Fusion):
    """ 1x1 convs after [x1; x2] """

    def __init__(self, input_dim=3):
        super(FusionConv, self).__init__(input_dim=input_dim)
        self.conv = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(input_dim * 2, input_dim, kernel_size=1, bias=False)
        )

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x


class FusionConvLat(Fusion):
    """ 1x1 convs after [x1; x2] for lateral fusion """

    def __init__(self, input_dim=3, output_dim=3):
        super(FusionConvLat, self).__init__(input_dim=input_dim)
        self.conv = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)
        )

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x


class FusionFiLM(Fusion):
    """ FiLM (Perez et al., https://arxiv.org/abs/1709.07871).

    Note: This is not used inside a Residual block before ReLU.
    I had a version of this in UpBlock with FiLM, which didn't seem to work at all.
    """

    def __init__(self, input_dim=3, output_dim=3):
        super(FusionFiLM, self).__init__(input_dim=input_dim)

    def forward(self, x1, x2, gamma, beta):
        # gamma and beta project x2 to per-channel scale and shift parameters,
        # which are then tiled over the spatial grid of x1
        g = self.tile_x2(x1, x2, gamma)
        b = self.tile_x2(x1, x2, beta)
        return x1 * g + b


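# Minimal FiLM usage sketch (hypothetical projections, not defined in this file):
#
#   film = FusionFiLM(input_dim=64)
#   gamma = nn.Linear(512, 64)         # language -> per-channel scale
#   beta = nn.Linear(512, 64)          # language -> per-channel shift
#   x1 = torch.randn(1, 64, 32, 32)    # visual features
#   x2 = torch.randn(1, 512)           # language embedding
#   out = film(x1, x2, gamma, beta)    # x1 * gamma(x2) + beta(x2), tiled spatially

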
class FusionDeepConv(Fusion):
    """ Multi-Layer 1x1 convs after [x1; x2] """

    def __init__(self, input_dim=3):
        super(FusionDeepConv, self).__init__(input_dim=input_dim)
        self.conv = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(input_dim * 2, input_dim, kernel_size=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(input_dim, input_dim, kernel_size=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(input_dim, input_dim, kernel_size=1, bias=False),
        )

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        if x1.shape != x2.shape and len(x1.shape) != len(x2.shape):
            x2 = self.tile_x2(x1, x2, x2_proj)
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x


class FusionMultWord(nn.Module):
    """ Product with weighted-sum of words """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        B, D, H, W = x1.shape
        x2_len = int(x2_mask.count_nonzero())

        # average the element-wise products of x1 with each (projected) word embedding
        weighted_x1 = torch.zeros_like(x1)
        for t in range(x2_len):
            x2_t = x2_proj(x2[:, t]) if x2_proj else x2[:, t]
            x2_t = x2_t.unsqueeze(-1).unsqueeze(-1).repeat(B, 1, H, W)
            weighted_x1 += x1 * x2_t
        weighted_x1 /= x2_len
        return weighted_x1


class FusionWordAttention(nn.Module):
    """ Word Attention """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.dot_attn = DotAttn()

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        B, D, H, W = x1.shape
        x1_flat = x1.reshape(B, D, H * W)
        x2_len = int(x2_mask.count_nonzero())

        # attend over spatial locations once per word, then average the reweighted maps
        weight_sum_x1_flat = torch.zeros_like(x1_flat)
        for t in range(x2_len):
            x2_t = x2_proj(x2[:, t]) if x2_proj else x2[:, t]
            x2_t = x2_t.repeat(B, 1)

            _, attn_x1 = self.dot_attn(x1_flat.transpose(1, 2), x2_t)
            weight_sum_x1_flat += x1_flat * attn_x1.transpose(1, 2)

        weight_sum_x1_flat /= x2_len
        x2 = weight_sum_x1_flat.reshape(B, D, H, W)
        return x2


class FusionSentenceAttention(nn.Module):
    """ Sentence Attention """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.dot_attn = ScaledDotAttn()

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        B, D, H, W = x1.shape
        x1_flat = x1.reshape(B, D, H * W)

        x2_t = x2_proj(x2) if x2_proj else x2
        x2_t = x2_t.repeat(B, 1)

        _, attn_x1 = self.dot_attn(x1_flat.transpose(1, 2), x2_t)
        weight_sum_x1_flat = x1_flat * attn_x1.transpose(1, 2)

        x2 = weight_sum_x1_flat.reshape(B, D, H, W)
        return x2


class CrossModalAttention2d(nn.Module):
    """ Cross-Modal Attention. Adapted from: https://github.com/openai/CLIP/blob/main/clip/model.py#L56 """

    def __init__(self, spacial_dim=7, embed_dim=1024, num_heads=32,
                 output_dim=1024, lang_dim=512, lang_max_tokens=77):
        super().__init__()
        self.embed_dim = embed_dim
        self.lang_dim = lang_dim
        self.lang_max_tokens = lang_max_tokens
        self.num_heads = num_heads
        self.lang_proj = nn.Linear(self.lang_dim, embed_dim)
        self.vision_positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2, embed_dim) / embed_dim ** 0.5)
        self.lang_positional_embedding = nn.Parameter(torch.randn(lang_max_tokens, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)

    def forward(self, x, l, l_mask):
        # flatten the spatial grid and add visual positional embeddings: (HW, B, D)
        x_shape = x.shape
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)
        x = x + self.vision_positional_embedding[:x.shape[0], None, :].to(x.dtype)

        # project language tokens to the embedding dim and add positional embeddings: (T, B, D)
        l = l.permute(1, 0, 2)
        l_shape = l.shape
        l = l.reshape(-1, self.lang_dim)
        l = self.lang_proj(l)
        l = l.reshape(l_shape[0], l_shape[1], self.embed_dim)
        l = l + self.lang_positional_embedding[:, None, :].to(l.dtype)

        # keep only the non-padded tokens and broadcast them across the batch
        l_len = int(l_mask.count_nonzero())
        l = l[:l_len]
        l = l.repeat(1, x.shape[1], 1)

        # vision queries attend over language keys/values
        x, _ = F.multi_head_attention_forward(
            query=x, key=l, value=l,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        x = x.permute(1, 2, 0)
        x = x.reshape(x_shape)
        return x


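# Illustrative shape sketch (assumed inputs, matching the defaults above):
#
#   attn = CrossModalAttention2d(spacial_dim=7, embed_dim=1024, output_dim=1024)
#   x = torch.randn(1, 1024, 7, 7)                    # visual feature map (B, D, H, W)
#   l = torch.randn(1, 77, 512)                       # token embeddings (B, T, D_lang)
#   l_mask = torch.zeros(1, 77); l_mask[0, :5] = 1    # 5 real tokens, rest padding
#   out = attn(x, l, l_mask)                          # same shape as x: (1, 1024, 7, 7)

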
class FusionMultiHeadedWordAttention(nn.Module):
    """ Multi-Headed Word Attention that uses Cross-Modal Attention at different scales """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim
        self.attn1 = CrossModalAttention2d(spacial_dim=7, embed_dim=1024, output_dim=1024)
        self.attn2 = CrossModalAttention2d(spacial_dim=14, embed_dim=512, output_dim=512)
        self.attn3 = CrossModalAttention2d(spacial_dim=28, embed_dim=256, output_dim=256)

        # keyed by channel dimension, so each scale gets its own attention module
        self.multi_headed_attns = {
            1024: self.attn1,
            512: self.attn2,
            256: self.attn3,
        }

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        emb_dim = x1.shape[1]
        x = self.multi_headed_attns[emb_dim](x1, x2, x2_mask)
        return x


names = {
    'add': FusionAdd,
    'mult': FusionMult,
    'mult_word': FusionMultWord,
    'film': FusionFiLM,
    'max': FusionMax,
    'concat': FusionConcat,
    'conv': FusionConv,
    'deep_conv': FusionDeepConv,
    'word_attn': FusionWordAttention,
    'sent_attn': FusionSentenceAttention,
    'multi_headed_word_attn': FusionMultiHeadedWordAttention,
}


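if __name__ == "__main__":
    # Minimal smoke-test sketch for the shape-agnostic fusions (illustrative
    # shapes only; the dims here are assumptions, not values used in training).
    x1 = torch.randn(2, 8, 16, 16)   # visual features (B, D, H, W)
    x2 = torch.randn(1, 8)           # pooled language feature (1, D)
    for name in ['add', 'mult', 'max', 'concat', 'conv', 'deep_conv']:
        fusion = names[name](input_dim=8)
        out = fusion(x1, x2)
        print(f"{name}: {tuple(out.shape)}")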