import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange


class Mlp(nn.Module):
    """Multi-Layer Perceptron (MLP)"""

    def __init__(self, in_dim, hidden_dim=None, out_dim=None, act_layer=nn.GELU):
        """
        Args:
            in_dim: input features dimension
            hidden_dim: hidden features dimension
            out_dim: output features dimension
            act_layer: activation function
        """
        super().__init__()
        out_dim = out_dim or in_dim
        hidden_dim = hidden_dim or in_dim
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.out_dim = out_dim

    def forward(self, x):
        x_size = x.size()
        x = x.view(-1, x_size[-1])
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = x.view(*x_size[:-1], self.out_dim)
        return x


class VanillaAttention(nn.Module):
    def __init__(self, dim, num_heads=8, proj_bias=False):
        """
        Args:
            dim: feature dimension
            num_heads: number of attention heads
            proj_bias: whether the query/key/value projections use a bias
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.softmax_temp = self.head_dim ** -0.5
        self.kv_proj = nn.Linear(dim, dim * 2, bias=proj_bias)
        self.q_proj = nn.Linear(dim, dim, bias=proj_bias)
        self.merge = nn.Linear(dim, dim)

    def forward(self, x_q, x_kv=None):
        """
        Args:
            x_q (torch.Tensor): [N, L, C]
            x_kv (torch.Tensor): [N, S, C]
        """
        if x_kv is None:
            x_kv = x_q
        bs, _, dim = x_q.shape

        # [N, S, 2, H, D] => [2, N, H, S, D]
        kv = self.kv_proj(x_kv).reshape(bs, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # [N, L, H, D] => [N, H, L, D]
        q = self.q_proj(x_q).reshape(bs, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k, v = kv[0].transpose(-2, -1).contiguous(), kv[1].contiguous()  # [N, H, D, S], [N, H, S, D]

        attn = (q @ k) * self.softmax_temp  # [N, H, L, S]
        attn = attn.softmax(dim=-1)
        x_q = (attn @ v).transpose(1, 2).reshape(bs, -1, dim)
        x_q = self.merge(x_q)
        return x_q


class CrossBidirectionalAttention(nn.Module):
    def __init__(self, dim, num_heads, proj_bias=False):
        """
        Args:
            dim: feature dimension
            num_heads: number of attention heads
            proj_bias: whether the query/key/value projections use a bias
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.softmax_temp = self.head_dim ** -0.5
        self.qk_proj = nn.Linear(dim, dim, bias=proj_bias)
        self.v_proj = nn.Linear(dim, dim, bias=proj_bias)
        self.merge = nn.Linear(dim, dim, bias=proj_bias)
        self.temperature = nn.Parameter(torch.tensor([0.0]), requires_grad=True)  # not used in forward

    def map_(self, func, x0, x1):
        return func(x0), func(x1)

    def forward(self, x0, x1):
        """
        Args:
            x0 (torch.Tensor): [N, L, C]
            x1 (torch.Tensor): [N, S, C]
        """
        bs = x0.size(0)

        qk0, qk1 = self.map_(self.qk_proj, x0, x1)
        v0, v1 = self.map_(self.v_proj, x0, x1)
        qk0, qk1, v0, v1 = map(
            lambda t: t.reshape(bs, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous(),
            (qk0, qk1, v0, v1))

        # Shared query/key projection: one similarity matrix is softmax-ed along
        # each direction to obtain both cross-attention maps.
        qk0, qk1 = qk0 * self.softmax_temp**0.5, qk1 * self.softmax_temp**0.5
        sim = qk0 @ qk1.transpose(-2, -1).contiguous()
        attn01 = F.softmax(sim, dim=-1)
        attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
        x0 = attn01 @ v1
        x1 = attn10 @ v0
        x0, x1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), x0, x1)
        x0, x1 = self.map_(self.merge, x0, x1)
        return x0, x1
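# Usage sketch for the attention blocks above (the tensor sizes are illustrative
# assumptions, not values fixed by this module):
#   attn = VanillaAttention(dim=256, num_heads=8)
#   y = attn(torch.randn(2, 100, 256))                         # [2, 100, 256]
#   cross = CrossBidirectionalAttention(dim=256, num_heads=8)
#   y0, y1 = cross(torch.randn(2, 64, 256), torch.randn(2, 81, 256))
#   # y0: [2, 64, 256] (x0 attending to x1), y1: [2, 81, 256] (x1 attending to x0)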
class SwinPosEmbMLP(nn.Module):
    """MLP-based positional embedding over a square window grid.

    The embedding is recomputed at every forward pass during training and
    cached after the first forward pass in eval mode.
    """

    def __init__(self, dim):
        super().__init__()
        self.pos_embed = None
        self.pos_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(),
                                     nn.Linear(512, dim, bias=False))

    def forward(self, x):
        seq_length = x.shape[1]
        if self.pos_embed is None or self.training:
            # Build normalized (x, y) coordinates of a side x side grid, where
            # side = sqrt(seq_length), then map them through the MLP.
            seq_length = int(seq_length**0.5)
            coords = torch.arange(0, seq_length, device=x.device, dtype=x.dtype)
            grid = torch.stack(torch.meshgrid([coords, coords])).contiguous().unsqueeze(0)
            grid -= seq_length // 2
            grid /= (seq_length // 2)
            self.pos_embed = self.pos_mlp(grid.flatten(2).transpose(1, 2))
        x = x + self.pos_embed
        return x


class WindowSelfAttention(nn.Module):
    """Self-attention over each local window, jointly with features coming from
    the previous (coarser) level."""

    def __init__(self, dim, num_heads, mlp_hidden_coef, use_pre_pos_embed=False):
        super().__init__()
        self.mlp = Mlp(in_dim=dim * 2, hidden_dim=dim * mlp_hidden_coef, out_dim=dim, act_layer=nn.GELU)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = VanillaAttention(dim, num_heads=num_heads)
        self.pos_embed = SwinPosEmbMLP(dim)
        self.pos_embed_pre = SwinPosEmbMLP(dim) if use_pre_pos_embed else nn.Identity()

    def forward(self, x, x_pre):
        ww = x.shape[1]
        ww_pre = x_pre.shape[1]
        x = self.pos_embed(x)
        x_pre = self.pos_embed_pre(x_pre)
        x = torch.cat((x, x_pre), dim=1)
        x = x + self.gamma * self.norm1(self.mlp(torch.cat([x, self.attn(self.norm2(x))], dim=-1)))
        x, x_pre = x.split([ww, ww_pre], dim=1)
        return x, x_pre


class WindowCrossAttention(nn.Module):
    """Bidirectional cross-attention between the local windows of the two views."""

    def __init__(self, dim, num_heads, mlp_hidden_coef):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(in_dim=dim * 2, hidden_dim=dim * mlp_hidden_coef, out_dim=dim, act_layer=nn.GELU)
        self.cross_attn = CrossBidirectionalAttention(dim, num_heads=num_heads, proj_bias=False)
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x0, x1):
        m_x0, m_x1 = self.cross_attn(self.norm1(x0), self.norm1(x1))
        x0 = x0 + self.gamma * self.norm2(self.mlp(torch.cat([x0, m_x0], dim=-1)))
        x1 = x1 + self.gamma * self.norm2(self.mlp(torch.cat([x1, m_x1], dim=-1)))
        return x0, x1
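# FineProcess pulls its hyper-parameters from a nested config dict. The keys
# below mirror the constructor's accesses; the values are placeholder
# assumptions for illustration, not the project's actual defaults:
#   config = {
#       'resnet': {'block_dims': [64, 128, 256]},   # fine / medium / coarse channels
#       'fine_window_size': 7,
#       'medium_window_size': 3,
#       'fine': {'nhead_fine_level': 4,
#                'nhead_medium_level': 8,
#                'mlp_hidden_dim_coef': 2},
#   }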
class FineProcess(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Config
        block_dims = config['resnet']['block_dims']
        self.block_dims = block_dims
        self.W_f = config['fine_window_size']
        self.W_m = config['medium_window_size']
        nhead_f = config["fine"]['nhead_fine_level']
        nhead_m = config["fine"]['nhead_medium_level']
        mlp_hidden_coef = config["fine"]['mlp_hidden_dim_coef']

        # Networks
        self.conv_merge = nn.Sequential(
            nn.Conv2d(block_dims[2] * 2, block_dims[1], kernel_size=1, stride=1, padding=0, bias=False),
            nn.Conv2d(block_dims[1], block_dims[1], kernel_size=3, stride=1, padding=1,
                      groups=block_dims[1], bias=False),
            nn.BatchNorm2d(block_dims[1]))
        self.out_conv_m = nn.Conv2d(block_dims[1], block_dims[1], kernel_size=1, stride=1, padding=0, bias=False)
        self.out_conv_f = nn.Conv2d(block_dims[0], block_dims[0], kernel_size=1, stride=1, padding=0, bias=False)
        self.self_attn_m = WindowSelfAttention(block_dims[1], num_heads=nhead_m,
                                               mlp_hidden_coef=mlp_hidden_coef, use_pre_pos_embed=False)
        self.cross_attn_m = WindowCrossAttention(block_dims[1], num_heads=nhead_m, mlp_hidden_coef=mlp_hidden_coef)
        self.self_attn_f = WindowSelfAttention(block_dims[0], num_heads=nhead_f,
                                               mlp_hidden_coef=mlp_hidden_coef, use_pre_pos_embed=True)
        self.cross_attn_f = WindowCrossAttention(block_dims[0], num_heads=nhead_f, mlp_hidden_coef=mlp_hidden_coef)
        self.down_proj_m_f = nn.Linear(block_dims[1], block_dims[0], bias=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def pre_process(self, feat_f0, feat_f1, feat_m0, feat_m1, feat_c0, feat_c1,
                    feat_c0_pre, feat_c1_pre, data):
        W_f = self.W_f
        W_m = self.W_m
        data.update({'W_f': W_f, 'W_m': W_m})

        # merge coarse features from before and after the LoFTR layer, and down-project the channel dimension
        feat_c0 = rearrange(feat_c0, 'n (h w) c -> n c h w', h=data["hw0_c"][0], w=data["hw0_c"][1])
        feat_c1 = rearrange(feat_c1, 'n (h w) c -> n c h w', h=data["hw1_c"][0], w=data["hw1_c"][1])
        feat_c0 = self.conv_merge(torch.cat([feat_c0, feat_c0_pre], dim=1))
        feat_c1 = self.conv_merge(torch.cat([feat_c1, feat_c1_pre], dim=1))
        feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) 1 c')
        feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) 1 c')

        stride_f = data['hw0_f'][0] // data['hw0_c'][0]
        stride_m = data['hw0_m'][0] // data['hw0_c'][0]

        if feat_m0.shape[2] == feat_m1.shape[2] and feat_m0.shape[3] == feat_m1.shape[3]:
            # same resolution for both views: batch them through the output convs
            feat_m = self.out_conv_m(torch.cat([feat_m0, feat_m1], dim=0))
            feat_m0, feat_m1 = torch.chunk(feat_m, 2, dim=0)
            feat_f = self.out_conv_f(torch.cat([feat_f0, feat_f1], dim=0))
            feat_f0, feat_f1 = torch.chunk(feat_f, 2, dim=0)
        else:
            feat_m0 = self.out_conv_m(feat_m0)
            feat_m1 = self.out_conv_m(feat_m1)
            feat_f0 = self.out_conv_f(feat_f0)
            feat_f1 = self.out_conv_f(feat_f1)
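        # F.unfold crops one (W x W) window per output location: it returns
        # [N, C * W**2, L], with L the number of window positions, which is then
        # rearranged to [N, L, W**2, C]. The stride is the medium/fine-to-coarse
        # resolution ratio, so window positions line up with coarse grid cells and
        # can be indexed directly with the coarse match indices i_ids / j_ids below.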
        # 1. unfold (crop windows) all local windows
        feat_m0_unfold = F.unfold(feat_m0, kernel_size=(W_m, W_m), stride=stride_m, padding=W_m // 2)
        feat_m0_unfold = rearrange(feat_m0_unfold, 'n (c ww) l -> n l ww c', ww=W_m**2)
        feat_m1_unfold = F.unfold(feat_m1, kernel_size=(W_m, W_m), stride=stride_m, padding=W_m // 2)
        feat_m1_unfold = rearrange(feat_m1_unfold, 'n (c ww) l -> n l ww c', ww=W_m**2)
        feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f // 2)
        feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W_f**2)
        feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W_f, W_f), stride=stride_f, padding=W_f // 2)
        feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W_f**2)

        # 2. select only the predicted matches
        feat_c0 = feat_c0[data['b_ids'], data['i_ids']]  # [n, ww, cm]
        feat_c1 = feat_c1[data['b_ids'], data['j_ids']]
        feat_m0_unfold = feat_m0_unfold[data['b_ids'], data['i_ids']]  # [n, ww, cm]
        feat_m1_unfold = feat_m1_unfold[data['b_ids'], data['j_ids']]
        feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']]  # [n, ww, cf]
        feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]

        return feat_c0, feat_c1, feat_m0_unfold, feat_m1_unfold, feat_f0_unfold, feat_f1_unfold

    def forward(self, feat_f0, feat_f1, feat_m0, feat_m1, feat_c0, feat_c1,
                feat_c0_pre, feat_c1_pre, data):
        """
        Args:
            feat_f0 (torch.Tensor): [N, C, H, W]
            feat_f1 (torch.Tensor): [N, C, H, W]
            feat_m0 (torch.Tensor): [N, C, H, W]
            feat_m1 (torch.Tensor): [N, C, H, W]
            feat_c0 (torch.Tensor): [N, L, C]
            feat_c1 (torch.Tensor): [N, S, C]
            feat_c0_pre (torch.Tensor): [N, C, H, W]
            feat_c1_pre (torch.Tensor): [N, C, H, W]
            data (dict): with keys ['hw0_c', 'hw1_c', 'hw0_m', 'hw1_m', 'hw0_f', 'hw1_f',
                'b_ids', 'i_ids', 'j_ids']
        """
        # TODO: Check for this case
        if data['b_ids'].shape[0] == 0:
            # no coarse matches: return empty fine windows
            feat0 = torch.empty(0, self.W_f**2, self.block_dims[0], device=feat_f0.device)
            feat1 = torch.empty(0, self.W_f**2, self.block_dims[0], device=feat_f0.device)
            return feat0, feat1

        feat_c0, feat_c1, feat_m0_unfold, feat_m1_unfold, \
            feat_f0_unfold, feat_f1_unfold = self.pre_process(feat_f0, feat_f1, feat_m0, feat_m1,
                                                              feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data)

        # self attention (c + m)
        feat_m_unfold, _ = self.self_attn_m(torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0),
                                            torch.cat([feat_c0, feat_c1], dim=0))
        feat_m0_unfold, feat_m1_unfold = torch.chunk(feat_m_unfold, 2, dim=0)
        # cross attention (m0 <-> m1)
        feat_m0_unfold, feat_m1_unfold = self.cross_attn_m(feat_m0_unfold, feat_m1_unfold)
        # down proj m
        feat_m_unfold = self.down_proj_m_f(torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0))
        feat_m0_unfold, feat_m1_unfold = torch.chunk(feat_m_unfold, 2, dim=0)
        # self attention (m + f)
        feat_f_unfold, _ = self.self_attn_f(torch.cat([feat_f0_unfold, feat_f1_unfold], dim=0),
                                            torch.cat([feat_m0_unfold, feat_m1_unfold], dim=0))
        feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_f_unfold, 2, dim=0)
        # cross attention (f0 <-> f1)
        feat_f0_unfold, feat_f1_unfold = self.cross_attn_f(feat_f0_unfold, feat_f1_unfold)

        return feat_f0_unfold, feat_f1_unfold
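# Minimal smoke test for FineProcess. Everything below (config values, feature
# resolutions, number of matches) is an illustrative assumption chosen so the
# shapes are consistent; it is not the project's training or inference setup.
if __name__ == "__main__":
    config = {
        'resnet': {'block_dims': [64, 128, 256]},   # assumed fine/medium/coarse channels
        'fine_window_size': 7,                      # assumed (odd) window sizes
        'medium_window_size': 3,
        'fine': {'nhead_fine_level': 4, 'nhead_medium_level': 8,
                 'mlp_hidden_dim_coef': 2},
    }
    model = FineProcess(config).eval()

    n, h_c, w_c = 1, 8, 8                # coarse grid
    h_m, w_m = 2 * h_c, 2 * w_c          # medium map: 2x coarse resolution
    h_f, w_f = 4 * h_c, 4 * w_c          # fine map: 4x coarse resolution
    bd = config['resnet']['block_dims']

    feat_c0 = torch.randn(n, h_c * w_c, bd[2])
    feat_c1 = torch.randn(n, h_c * w_c, bd[2])
    feat_c0_pre = torch.randn(n, bd[2], h_c, w_c)
    feat_c1_pre = torch.randn(n, bd[2], h_c, w_c)
    feat_m0 = torch.randn(n, bd[1], h_m, w_m)
    feat_m1 = torch.randn(n, bd[1], h_m, w_m)
    feat_f0 = torch.randn(n, bd[0], h_f, w_f)
    feat_f1 = torch.randn(n, bd[0], h_f, w_f)

    # Fake coarse matches: 5 matches, all in the first (and only) batch element.
    data = {
        'hw0_c': (h_c, w_c), 'hw1_c': (h_c, w_c),
        'hw0_m': (h_m, w_m), 'hw1_m': (h_m, w_m),
        'hw0_f': (h_f, w_f), 'hw1_f': (h_f, w_f),
        'b_ids': torch.zeros(5, dtype=torch.long),
        'i_ids': torch.arange(5),
        'j_ids': torch.arange(5),
    }

    with torch.no_grad():
        out0, out1 = model(feat_f0, feat_f1, feat_m0, feat_m1,
                           feat_c0, feat_c1, feat_c0_pre, feat_c1_pre, data)
    print(out0.shape, out1.shape)  # expected: [5, 7**2, 64] for each view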