import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from migc.migc_layers import CBAM, CrossAttention, LayoutAttention


class FourierEmbedder():
    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)  # torch.Size([5, 30, 64])
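
# Illustrative usage (not executed here): the embedder expands the last dimension
# by 2 * num_freqs (sin and cos per frequency band), e.g. with num_freqs=8 a
# (5, 1, 4) xyxy box tensor becomes a (5, 1, 64) feature tensor.
#   embedder = FourierEmbedder(num_freqs=8)
#   feats = embedder(torch.rand(5, 1, 4))  # -> torch.Size([5, 1, 64])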


class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin & cos, 4 is xyxy

        # -------------------------------------------------------------- #
        self.linears_position = nn.Sequential(
            nn.Linear(self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

    def forward(self, boxes):
        # embed box positions (the input may include padding as a placeholder)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*1*4 --> B*1*C, torch.Size([5, 1, 64])
        xyxy_embedding = self.linears_position(xyxy_embedding)  # B*1*C --> B*1*768, torch.Size([5, 1, 768])
        return xyxy_embedding
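
# Illustrative usage (not executed here): PositionNet turns each xyxy box into a
# single context token; with the default fourier_freqs=8 the intermediate Fourier
# feature has 8 * 2 * 4 = 64 dims before the MLP projects it to out_dim.
#   pos_net = PositionNet(in_dim=768, out_dim=768)
#   tokens = pos_net(torch.rand(5, 1, 4))  # -> torch.Size([5, 1, 768])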


class SAC(nn.Module):
    def __init__(self, C, number_pro=30):
        super().__init__()
        self.C = C
        self.number_pro = number_pro
        self.conv1 = nn.Conv2d(C + 1, C, 1, 1)
        self.cbam1 = CBAM(C)
        self.conv2 = nn.Conv2d(C, 1, 1, 1)
        self.cbam2 = CBAM(number_pro, reduction_ratio=1)

    def forward(self, x, guidance_mask, sac_scale=None):
        '''
        :param x: (B, phase_num, HW, C)
        :param guidance_mask: (B, phase_num, H, W)
        :return:
        '''
        B, phase_num, HW, C = x.shape
        _, _, H, W = guidance_mask.shape
        guidance_mask = guidance_mask.view(guidance_mask.shape[0], phase_num, -1)[
            ..., None]  # (B, phase_num, HW, 1)

        null_x = torch.zeros_like(x[:, [0], ...]).to(x.device)
        null_mask = torch.zeros_like(guidance_mask[:, [0], ...]).to(guidance_mask.device)
        x = torch.cat([x, null_x], dim=1)
        guidance_mask = torch.cat([guidance_mask, null_mask], dim=1)
        phase_num += 1

        scale = torch.cat([x, guidance_mask], dim=-1)  # (B, phase_num, HW, C+1)
        scale = scale.view(-1, H, W, C + 1)  # (B * phase_num, H, W, C+1)
        scale = scale.permute(0, 3, 1, 2)  # (B * phase_num, C+1, H, W)
        scale = self.conv1(scale)  # (B * phase_num, C, H, W)
        scale = self.cbam1(scale)  # (B * phase_num, C, H, W)
        scale = self.conv2(scale)  # (B * phase_num, 1, H, W)
        scale = scale.view(B, phase_num, H, W)  # (B, phase_num, H, W)

        null_scale = scale[:, [-1], ...]
        scale = scale[:, :-1, ...]
        x = x[:, :-1, ...]

        # Pad the per-instance scale maps up to number_pro channels with copies of the
        # null scale, then shuffle their order; the shuffle is undone below via
        # inv_shuffled_order before the softmax.
        pad_num = self.number_pro - phase_num + 1
        ori_phase_num = scale[:, 1:-1, ...].shape[1]
        phase_scale = torch.cat([scale[:, 1:-1, ...], null_scale.repeat(1, pad_num, 1, 1)], dim=1)
        shuffled_order = torch.randperm(phase_scale.shape[1])
        inv_shuffled_order = torch.argsort(shuffled_order)
        random_phase_scale = phase_scale[:, shuffled_order, ...]

        scale = torch.cat([scale[:, [0], ...], random_phase_scale, scale[:, [-1], ...]], dim=1)
        # (B, number_pro, H, W)

        scale = self.cbam2(scale)  # (B, number_pro, H, W)
        scale = scale.view(B, self.number_pro, HW)[..., None]  # (B, number_pro, HW, 1)

        random_phase_scale = scale[:, 1:-1, ...]
        phase_scale = random_phase_scale[:, inv_shuffled_order[:ori_phase_num], :]

        if sac_scale is not None:
            instance_num = len(sac_scale)
            for i in range(instance_num):
                phase_scale[:, i, ...] = phase_scale[:, i, ...] * sac_scale[i]

        scale = torch.cat([scale[:, [0], ...], phase_scale, scale[:, [-1], ...]], dim=1)
        scale = scale.softmax(dim=1)  # (B, phase_num, HW, 1)
        out = (x * scale).sum(dim=1, keepdim=True)  # (B, 1, HW, C)
        return out, scale
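
# Illustrative usage (not executed here): SAC fuses the per-phase features into a
# single feature map by predicting per-pixel softmax weights over the phases.
#   sac = SAC(C=320, number_pro=30)
#   out, scale = sac(x, guidance_mask)  # x: (B, phase_num, HW, C), mask: (B, phase_num, H, W)
#   # out: (B, 1, HW, C); scale: per-phase softmax weights of shape (B, phase_num, HW, 1)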


class MIGC(nn.Module):
    def __init__(self, C, attn_type='base', context_dim=768, heads=8):
        super().__init__()
        self.ea = CrossAttention(query_dim=C, context_dim=context_dim,
                                 heads=heads, dim_head=C // heads,
                                 dropout=0.0)
        self.la = LayoutAttention(query_dim=C,
                                  heads=heads, dim_head=C // heads,
                                  dropout=0.0)
        self.norm = nn.LayerNorm(C)
        self.sac = SAC(C)
        self.pos_net = PositionNet(in_dim=768, out_dim=768)

    def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
        # ca_x: (B, instance_num+1, HW, C)
        # guidance_mask: (B, instance_num, H, W)
        # box: (instance_num, 4)
        # image_token: (B, instance_num+1, HW, C)
        full_H = other_info['height']
        full_W = other_info['width']
        B, _, HW, C = ca_x.shape
        instance_num = guidance_mask.shape[1]
        down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
        H = full_H // down_scale
        W = full_W // down_scale
        guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear')  # (B, instance_num, H, W)
        supplement_mask = other_info['supplement_mask']  # (B, 1, 64, 64)
        supplement_mask = F.interpolate(supplement_mask, size=(H, W), mode='bilinear')  # (B, 1, H, W)

        image_token = other_info['image_token']
        assert image_token.shape == ca_x.shape
        context = other_info['context_pooler']
        box = other_info['box']
        box = box.view(B * instance_num, 1, -1)
        box_token = self.pos_net(box)
        context = torch.cat([context[1:, ...], box_token], dim=1)

        ca_scale = other_info['ca_scale'] if 'ca_scale' in other_info else None
        ea_scale = other_info['ea_scale'] if 'ea_scale' in other_info else None
        sac_scale = other_info['sac_scale'] if 'sac_scale' in other_info else None

        # Cross-attention (self.ea) between each instance's image tokens and its
        # box-token-augmented text context.
        ea_x, ea_attn = self.ea(self.norm(image_token[:, 1:, ...].view(B * instance_num, HW, C)),
                                context=context, return_attn=True)
        ea_x = ea_x.view(B, instance_num, HW, C)
        ea_x = ea_x * guidance_mask.view(B, instance_num, HW, 1)
        ca_x[:, 1:, ...] = ca_x[:, 1:, ...] * guidance_mask.view(B, instance_num, HW, 1)  # (B, phase_num, HW, C)

        if ca_scale is not None:
            assert len(ca_scale) == instance_num
            for i in range(instance_num):
                ca_x[:, i + 1, ...] = ca_x[:, i + 1, ...] * ca_scale[i] + ea_x[:, i, ...] * ea_scale[i]
        else:
            ca_x[:, 1:, ...] = ca_x[:, 1:, ...] + ea_x

        # Layout attention (self.la) over the original image tokens produces the fusion template.
        ori_image_token = image_token[:, 0, ...]  # (B, HW, C)
        fusion_template = self.la(x=ori_image_token,
                                  guidance_mask=torch.cat([guidance_mask, supplement_mask], dim=1))  # (B, HW, C)
        fusion_template = fusion_template.view(B, 1, HW, C)  # (B, 1, HW, C)

        ca_x = torch.cat([ca_x, fusion_template], dim=1)
        ca_x[:, 0, ...] = ca_x[:, 0, ...] * supplement_mask.view(B, HW, 1)
        guidance_mask = torch.cat([
            supplement_mask,
            guidance_mask,
            torch.ones(B, 1, H, W).to(guidance_mask.device)
        ], dim=1)

        out_MIGC, sac_scale = self.sac(ca_x, guidance_mask, sac_scale=sac_scale)
        if return_fuser_info:
            fuser_info = {}
            fuser_info['sac_scale'] = sac_scale.view(B, instance_num + 2, H, W)
            fuser_info['ea_attn'] = ea_attn.mean(dim=1).view(B, instance_num, H, W, 2)
            return out_MIGC, fuser_info
        else:
            return out_MIGC
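
# Minimal sketch of the inputs MIGC.forward expects (key names are taken from the
# code above; the placeholder values are hypothetical, not real pipeline outputs):
#   other_info = {
#       'height': 512, 'width': 512,        # full image resolution
#       'supplement_mask': ...,             # (B, 1, 64, 64) background mask
#       'image_token': ...,                 # (B, instance_num+1, HW, C), same shape as ca_x
#       'context_pooler': ...,              # text context, sliced as context[1:, ...] in forward
#       'box': ...,                         # xyxy boxes, reshaped to (B*instance_num, 1, 4) in forward
#       # optional per-instance strengths: 'ca_scale', 'ea_scale', 'sac_scale'
#   }
#   out = migc(ca_x, guidance_mask, other_info)  # (B, 1, HW, C)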


class NaiveFuser(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
        # ca_x: (B, instance_num+1, HW, C)
        # guidance_mask: (B, instance_num, H, W)
        # box: (instance_num, 4)
        # image_token: (B, instance_num+1, HW, C)
        full_H = other_info['height']
        full_W = other_info['width']
        B, _, HW, C = ca_x.shape
        instance_num = guidance_mask.shape[1]
        down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
        H = full_H // down_scale
        W = full_W // down_scale
        guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear')  # (B, instance_num, H, W)
        guidance_mask = torch.cat([torch.ones(B, 1, H, W).to(guidance_mask.device), guidance_mask * 10], dim=1)  # (B, instance_num+1, H, W)
        guidance_mask = guidance_mask.view(B, instance_num + 1, HW, 1)
        out_MIGC = (ca_x * guidance_mask).sum(dim=1) / (guidance_mask.sum(dim=1) + 1e-6)
        if return_fuser_info:
            return out_MIGC, None
        else:
            return out_MIGC
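

# Minimal shape-check sketch (assumption: run as a script with PyTorch available).
# It uses random tensors and only exercises the components that need no external
# inputs beyond this file; MIGC itself additionally needs the full `other_info`
# dict sketched above.
if __name__ == '__main__':
    # PositionNet: one 768-dim token per xyxy box.
    pos_net = PositionNet(in_dim=768, out_dim=768)
    box_tokens = pos_net(torch.rand(2, 1, 4))
    print(box_tokens.shape)  # torch.Size([2, 1, 768])

    # NaiveFuser: weighted average of the phases, with instance masks upweighted by 10.
    B, instance_num, C, H, W = 1, 2, 320, 64, 64
    ca_x = torch.rand(B, instance_num + 1, H * W, C)
    guidance_mask = torch.rand(B, instance_num, H, W)
    fuser = NaiveFuser()
    out = fuser(ca_x, guidance_mask, {'height': 512, 'width': 512})
    print(out.shape)  # torch.Size([1, 4096, 320])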