import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange

from .linear_attention import Attention
from ..utils.position_encoding import RoPEPositionEncodingSine


class AG_RoPE_EncoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,
agg_size0=4,
agg_size1=4,
no_flash=False,
rope=False,
npe=None,
fp32=False,
):
        super().__init__()
self.dim = d_model // nhead
self.nhead = nhead
self.agg_size0, self.agg_size1 = agg_size0, agg_size1
self.rope = rope
        # Token aggregation (depthwise strided conv for queries, max-pool for
        # keys/values) and RoPE position encoding
        self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=agg_size0, padding=0, stride=agg_size0,
                                   bias=False, groups=d_model) if self.agg_size0 != 1 else nn.Identity()
        self.max_pool = nn.MaxPool2d(kernel_size=self.agg_size1, stride=self.agg_size1) if self.agg_size1 != 1 else nn.Identity()
        self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)
# multi-head attention
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.attention = Attention(no_flash, self.nhead, self.dim, fp32)
self.merge = nn.Linear(d_model, d_model, bias=False)
# feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2, bias=False),
            nn.LeakyReLU(inplace=True),
            nn.Linear(d_model * 2, d_model, bias=False),
        )
# norm
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, source, x_mask=None, source_mask=None):
"""
Args:
x (torch.Tensor): [N, C, H0, W0]
source (torch.Tensor): [N, C, H1, W1]
x_mask (torch.Tensor): [N, H0, W0] (optional) (L = H0*W0)
source_mask (torch.Tensor): [N, H1, W1] (optional) (S = H1*W1)
"""
bs, C, H0, W0 = x.size()
H1, W1 = source.size(-2), source.size(-1)
        # Aggregate features. Padding masks are normally consumed by the crop in
        # LocalFeatureTransformer.forward, hence the assert; the branch below
        # only matters if that invariant is relaxed.
        assert x_mask is None and source_mask is None
        query = self.norm1(self.aggregate(x).permute(0, 2, 3, 1))       # [N, H0', W0', C]
        source = self.norm1(self.max_pool(source).permute(0, 2, 3, 1))  # [N, H1', W1', C]
        if x_mask is not None:
            x_mask, source_mask = map(lambda m: self.max_pool(m.float()).bool(), [x_mask, source_mask])

        query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)
# Positional encoding
if self.rope:
query = self.rope_pos_enc(query)
key = self.rope_pos_enc(key)
        # Multi-head attention; padding masks (if any) are forwarded to the attention kernel
        m = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)
        m = self.merge(m.reshape(bs, -1, self.nhead * self.dim))  # [N, L, C]
        # Upsample aggregated tokens back to the input resolution
        m = rearrange(m, 'b (h w) c -> b c h w', h=H0 // self.agg_size0, w=W0 // self.agg_size0)  # [N, C, H0//agg_size0, W0//agg_size0]
        if self.agg_size0 != 1:
            m = F.interpolate(m, scale_factor=self.agg_size0, mode='bilinear', align_corners=False)  # [N, C, H0, W0]
# feed-forward network
m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1)) # [N, H0, W0, C]
m = self.norm2(m).permute(0, 3, 1, 2) # [N, C, H0, W0]
return x + m
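
# --- Usage sketch (illustrative, not part of the original module). A minimal
# smoke test for AG_RoPE_EncoderLayer; all argument values below are
# assumptions. `npe` is forwarded verbatim to RoPEPositionEncodingSine, so the
# 4-element list here mirrors EfficientLoFTR-style train/test sizes and may
# need adjusting to your position-encoding signature.
def _demo_ag_rope_layer():
    layer = AG_RoPE_EncoderLayer(d_model=256, nhead=8, agg_size0=4, agg_size1=4,
                                 no_flash=True, rope=True,
                                 npe=[832, 832, 832, 832], fp32=True)
    x = torch.randn(1, 256, 64, 64)       # [N, C, H0, W0]; H0, W0 divisible by agg_size0
    source = torch.randn(1, 256, 64, 64)  # [N, C, H1, W1]
    out = layer(x, source)                # residual update, same shape as x
    assert out.shape == x.shape
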
class LocalFeatureTransformer(nn.Module):
"""A Local Feature Transformer (LoFTR) module."""
def __init__(self, config):
        super().__init__()
self.full_config = config
        # Attention runs in fp32 unless mixed-precision or half-precision is enabled
        self.fp32 = not (config['mp'] or config['half'])
config = config['coarse']
self.d_model = config['d_model']
self.nhead = config['nhead']
self.layer_names = config['layer_names']
self.agg_size0, self.agg_size1 = config['agg_size0'], config['agg_size1']
self.rope = config['rope']
        self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'],
                                          config['agg_size0'], config['agg_size1'],
                                          config['no_flash'], config['rope'],
                                          config['npe'], self.fp32)
        # Cross-attention layers never use RoPE (rope=False): relative positions
        # are not meaningful across two different images.
        cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'],
                                           config['agg_size0'], config['agg_size1'],
                                           config['no_flash'], False,
                                           config['npe'], self.fp32)
        self.layers = nn.ModuleList([copy.deepcopy(self_layer) if name == 'self' else copy.deepcopy(cross_layer)
                                     for name in self.layer_names])
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
"""
Args:
feat0 (torch.Tensor): [N, C, H, W]
feat1 (torch.Tensor): [N, C, H, W]
            mask0 (torch.Tensor): [N, H, W] (optional)
            mask1 (torch.Tensor): [N, H, W] (optional)
"""
H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
bs = feat0.shape[0]
feature_cropped = False
        if bs == 1 and mask0 is not None and mask1 is not None:
            # Crop away the zero-padded border so attention only sees valid
            # pixels; crop sizes are rounded down to multiples of the
            # aggregation sizes so tokens tile the cropped features exactly.
            mask_H0, mask_W0, mask_H1, mask_W1 = mask0.size(-2), mask0.size(-1), mask1.size(-2), mask1.size(-1)
            mask_h0, mask_w0 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0]  # valid height/width of image 0
            mask_h1, mask_w1 = mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]  # valid height/width of image 1
            mask_h0, mask_w0 = mask_h0 // self.agg_size0 * self.agg_size0, mask_w0 // self.agg_size0 * self.agg_size0
            mask_h1, mask_w1 = mask_h1 // self.agg_size1 * self.agg_size1, mask_w1 // self.agg_size1 * self.agg_size1
            feat0 = feat0[:, :, :mask_h0, :mask_w0]
            feat1 = feat1[:, :, :mask_h1, :mask_w1]
            feature_cropped = True
        for layer, name in zip(self.layers, self.layer_names):
            if feature_cropped:
                # Masks were already consumed by cropping; do not pass them on.
                mask0, mask1 = None, None
            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError(f'Unknown layer name: {name}')
        if feature_cropped:
            # Zero-pad the features back to the pre-crop resolution so
            # downstream modules see the original spatial size; F.pad handles
            # height and width independently and is a no-op when nothing was
            # cropped along that dimension.
            bs, c, mask_h0, mask_w0 = feat0.size()
            feat0 = F.pad(feat0, (0, mask_W0 - mask_w0, 0, mask_H0 - mask_h0))
            bs, c, mask_h1, mask_w1 = feat1.size()
            feat1 = F.pad(feat1, (0, mask_W1 - mask_w1, 0, mask_H1 - mask_h1))
return feat0, feat1
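

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module). The config
    # below is an assumption modeled on EfficientLoFTR-style configs: the key
    # names match the lookups in __init__ above, the values are placeholders.
    cfg = {
        'mp': False,
        'half': False,
        'coarse': {
            'd_model': 256,
            'nhead': 8,
            'layer_names': ['self', 'cross'] * 2,
            'agg_size0': 4,
            'agg_size1': 4,
            'no_flash': True,
            'rope': True,
            'npe': [832, 832, 832, 832],  # assumed RoPE train/test sizes
        },
    }
    model = LocalFeatureTransformer(cfg).eval()
    feat0 = torch.randn(1, 256, 60, 80)  # coarse features; H, W divisible by the agg sizes
    feat1 = torch.randn(1, 256, 60, 80)
    with torch.no_grad():
        out0, out1 = model(feat0, feat1)
    print(out0.shape, out1.shape)  # torch.Size([1, 256, 60, 80]) twice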