import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from migc.migc_layers import CBAM, CrossAttention, LayoutAttention

class FourierEmbedder:
    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    def __call__(self, x, cat_dim=-1):
        # Encode every input coordinate with sin/cos at each frequency band, then
        # concatenate along `cat_dim`: (..., D) -> (..., num_freqs * 2 * D).
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)

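# Usage sketch (illustrative only, not part of the original module): with
# num_freqs=8, a (5, 1, 4) xyxy tensor maps to (5, 1, 64), i.e.
# 8 frequencies * (sin, cos) * 4 coordinates.
#
#   embedder = FourierEmbedder(num_freqs=8)
#   boxes = torch.rand(5, 1, 4)
#   emb = embedder(boxes)   # -> torch.Size([5, 1, 64])
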
class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin & cos, 4 is xyxy

        # -------------------------------------------------------------- #
        self.linears_position = nn.Sequential(
            nn.Linear(self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

    def forward(self, boxes):
        # Embed the box positions (they may include padding as placeholders).
        xyxy_embedding = self.fourier_embedder(boxes)            # (B, 1, 4) --> (B, 1, position_dim)
        xyxy_embedding = self.linears_position(xyxy_embedding)   # (B, 1, position_dim) --> (B, 1, out_dim)
        return xyxy_embedding

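# Usage sketch (illustrative only): PositionNet turns normalized xyxy boxes into
# token embeddings of width out_dim; with fourier_freqs=8 the intermediate
# Fourier feature has 8 * 2 * 4 = 64 channels.
#
#   pos_net = PositionNet(in_dim=768, out_dim=768)
#   boxes = torch.rand(5, 1, 4)            # hypothetical batch of 5 instance boxes
#   tokens = pos_net(boxes)                # -> torch.Size([5, 1, 768])
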
class SAC(nn.Module):
    def __init__(self, C, number_pro=30):
        super().__init__()
        self.C = C
        self.number_pro = number_pro
        self.conv1 = nn.Conv2d(C + 1, C, 1, 1)
        self.cbam1 = CBAM(C)
        self.conv2 = nn.Conv2d(C, 1, 1, 1)
        self.cbam2 = CBAM(number_pro, reduction_ratio=1)

    def forward(self, x, guidance_mask, sac_scale=None):
        '''
        :param x: (B, phase_num, HW, C)
        :param guidance_mask: (B, phase_num, H, W)
        :param sac_scale: optional per-instance weights applied to the instance phases
        :return: fused feature (B, 1, HW, C) and softmax weights (B, phase_num, HW, 1)
        '''
        B, phase_num, HW, C = x.shape
        _, _, H, W = guidance_mask.shape
        guidance_mask = guidance_mask.view(guidance_mask.shape[0], phase_num, -1)[..., None]  # (B, phase_num, HW, 1)

        # Append an all-zero "null" phase; copies of it are used below to pad up to number_pro.
        null_x = torch.zeros_like(x[:, [0], ...]).to(x.device)
        null_mask = torch.zeros_like(guidance_mask[:, [0], ...]).to(guidance_mask.device)
        x = torch.cat([x, null_x], dim=1)
        guidance_mask = torch.cat([guidance_mask, null_mask], dim=1)
        phase_num += 1

        # Predict a per-pixel scale for every phase.
        scale = torch.cat([x, guidance_mask], dim=-1)  # (B, phase_num, HW, C+1)
        scale = scale.view(-1, H, W, C + 1)            # (B * phase_num, H, W, C+1)
        scale = scale.permute(0, 3, 1, 2)              # (B * phase_num, C+1, H, W)
        scale = self.conv1(scale)                      # (B * phase_num, C, H, W)
        scale = self.cbam1(scale)                      # (B * phase_num, C, H, W)
        scale = self.conv2(scale)                      # (B * phase_num, 1, H, W)
        scale = scale.view(B, phase_num, H, W)         # (B, phase_num, H, W)

        null_scale = scale[:, [-1], ...]
        scale = scale[:, :-1, ...]
        x = x[:, :-1, ...]

        # Pad the instance phases with copies of the null scale so cbam2 always sees
        # number_pro channels, and shuffle them so the result is order-invariant.
        pad_num = self.number_pro - phase_num + 1
        ori_phase_num = scale[:, 1:-1, ...].shape[1]
        phase_scale = torch.cat([scale[:, 1:-1, ...], null_scale.repeat(1, pad_num, 1, 1)], dim=1)
        shuffled_order = torch.randperm(phase_scale.shape[1])
        inv_shuffled_order = torch.argsort(shuffled_order)
        random_phase_scale = phase_scale[:, shuffled_order, ...]

        scale = torch.cat([scale[:, [0], ...], random_phase_scale, scale[:, [-1], ...]], dim=1)  # (B, number_pro, H, W)
        scale = self.cbam2(scale)                                # (B, number_pro, H, W)
        scale = scale.view(B, self.number_pro, HW)[..., None]    # (B, number_pro, HW, 1)

        # Undo the shuffle and drop the padding.
        random_phase_scale = scale[:, 1:-1, ...]
        phase_scale = random_phase_scale[:, inv_shuffled_order[:ori_phase_num], :]
        if sac_scale is not None:
            instance_num = len(sac_scale)
            for i in range(instance_num):
                phase_scale[:, i, ...] = phase_scale[:, i, ...] * sac_scale[i]

        scale = torch.cat([scale[:, [0], ...], phase_scale, scale[:, [-1], ...]], dim=1)
        scale = scale.softmax(dim=1)                   # (B, phase_num, HW, 1)
        out = (x * scale).sum(dim=1, keepdim=True)     # (B, 1, HW, C)
        return out, scale

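# Usage sketch (illustrative only, shapes are hypothetical): SAC fuses the
# phase-wise features into a single feature map with predicted per-pixel
# weights; phase_num must not exceed number_pro (30 by default).
#
#   sac = SAC(C=320)
#   x = torch.rand(2, 4, 64 * 64, 320)        # (B, phase_num, HW, C)
#   mask = torch.rand(2, 4, 64, 64)           # (B, phase_num, H, W)
#   out, scale = sac(x, mask)                 # out: (2, 1, 4096, 320), scale: (2, 4, 4096, 1)
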
class MIGC(nn.Module):
    def __init__(self, C, attn_type='base', context_dim=768, heads=8):
        super().__init__()
        self.ea = CrossAttention(query_dim=C, context_dim=context_dim,
                                 heads=heads, dim_head=C // heads,
                                 dropout=0.0)
        self.la = LayoutAttention(query_dim=C,
                                  heads=heads, dim_head=C // heads,
                                  dropout=0.0)
        self.norm = nn.LayerNorm(C)
        self.sac = SAC(C)
        self.pos_net = PositionNet(in_dim=768, out_dim=768)

    def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
        # ca_x: (B, instance_num+1, HW, C)
        # guidance_mask: (B, instance_num, H, W)
        # box: (instance_num, 4)
        # image_token: (B, instance_num+1, HW, C)
        full_H = other_info['height']
        full_W = other_info['width']
        B, _, HW, C = ca_x.shape
        instance_num = guidance_mask.shape[1]
        down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
        H = full_H // down_scale
        W = full_W // down_scale
        guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear')      # (B, instance_num, H, W)
        supplement_mask = other_info['supplement_mask']                                 # (B, 1, 64, 64)
        supplement_mask = F.interpolate(supplement_mask, size=(H, W), mode='bilinear')  # (B, 1, H, W)

        image_token = other_info['image_token']
        assert image_token.shape == ca_x.shape
        context = other_info['context_pooler']
        box = other_info['box']
        box = box.view(B * instance_num, 1, -1)
        box_token = self.pos_net(box)
        context = torch.cat([context[1:, ...], box_token], dim=1)

        ca_scale = other_info['ca_scale'] if 'ca_scale' in other_info else None
        ea_scale = other_info['ea_scale'] if 'ea_scale' in other_info else None
        sac_scale = other_info['sac_scale'] if 'sac_scale' in other_info else None

        # Cross-attend each instance's image tokens to its text + box tokens,
        # then keep the response only inside that instance's mask.
        ea_x, ea_attn = self.ea(self.norm(image_token[:, 1:, ...].view(B * instance_num, HW, C)),
                                context=context, return_attn=True)
        ea_x = ea_x.view(B, instance_num, HW, C)
        ea_x = ea_x * guidance_mask.view(B, instance_num, HW, 1)

        ca_x[:, 1:, ...] = ca_x[:, 1:, ...] * guidance_mask.view(B, instance_num, HW, 1)  # (B, phase_num, HW, C)
        if ca_scale is not None:
            assert len(ca_scale) == instance_num
            for i in range(instance_num):
                ca_x[:, i + 1, ...] = ca_x[:, i + 1, ...] * ca_scale[i] + ea_x[:, i, ...] * ea_scale[i]
        else:
            ca_x[:, 1:, ...] = ca_x[:, 1:, ...] + ea_x

        # Layout attention over the full image tokens produces a fusion template phase.
        ori_image_token = image_token[:, 0, ...]  # (B, HW, C)
        fusion_template = self.la(x=ori_image_token,
                                  guidance_mask=torch.cat([guidance_mask, supplement_mask], dim=1))  # (B, HW, C)
        fusion_template = fusion_template.view(B, 1, HW, C)  # (B, 1, HW, C)

        ca_x = torch.cat([ca_x, fusion_template], dim=1)  # (B, instance_num+2, HW, C)
        ca_x[:, 0, ...] = ca_x[:, 0, ...] * supplement_mask.view(B, HW, 1)
        guidance_mask = torch.cat([
            supplement_mask,
            guidance_mask,
            torch.ones(B, 1, H, W).to(guidance_mask.device)
        ], dim=1)  # (B, instance_num+2, H, W)

        out_MIGC, sac_scale = self.sac(ca_x, guidance_mask, sac_scale=sac_scale)
        if return_fuser_info:
            fuser_info = {}
            fuser_info['sac_scale'] = sac_scale.view(B, instance_num + 2, H, W)
            fuser_info['ea_attn'] = ea_attn.mean(dim=1).view(B, instance_num, H, W, 2)
            return out_MIGC, fuser_info
        else:
            return out_MIGC

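# Usage sketch (illustrative only): based on the forward pass above, `other_info`
# must carry at least 'height', 'width', 'supplement_mask' (B, 1, 64, 64),
# 'image_token' (same shape as ca_x), 'context_pooler', and 'box' (normalized
# xyxy per instance); 'ca_scale' / 'ea_scale' / 'sac_scale' are optional
# per-instance weight lists.
#
#   migc = MIGC(C=320, context_dim=768, heads=8)
#   out, info = migc(ca_x, guidance_mask, other_info, return_fuser_info=True)
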
class NaiveFuser(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, ca_x, guidance_mask, other_info, return_fuser_info=False):
        # ca_x: (B, instance_num+1, HW, C)
        # guidance_mask: (B, instance_num, H, W)
        # box: (instance_num, 4)
        # image_token: (B, instance_num+1, HW, C)
        full_H = other_info['height']
        full_W = other_info['width']
        B, _, HW, C = ca_x.shape
        instance_num = guidance_mask.shape[1]
        down_scale = int(math.sqrt(full_H * full_W // ca_x.shape[2]))
        H = full_H // down_scale
        W = full_W // down_scale
        guidance_mask = F.interpolate(guidance_mask, size=(H, W), mode='bilinear')  # (B, instance_num, H, W)
        guidance_mask = torch.cat([torch.ones(B, 1, H, W).to(guidance_mask.device), guidance_mask * 10], dim=1)  # (B, instance_num+1, H, W)
        guidance_mask = guidance_mask.view(B, instance_num + 1, HW, 1)
        out_MIGC = (ca_x * guidance_mask).sum(dim=1) / (guidance_mask.sum(dim=1) + 1e-6)
        if return_fuser_info:
            return out_MIGC, None
        else:
            return out_MIGC
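# Usage sketch (illustrative only): NaiveFuser is a parameter-free fallback that
# averages the phases with their (upweighted) guidance masks instead of learned
# SAC weights, so it can be swapped in wherever MIGC is called.
#
#   fuser = NaiveFuser()
#   out = fuser(ca_x, guidance_mask, other_info)   # same call signature as MIGC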