|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .backbone import CNNEncoder
|
|
from .transformer import FeatureTransformer, FeatureFlowAttention
|
|
from .matching import global_correlation_softmax, local_correlation_softmax
|
|
from .geometry import flow_warp
|
|
from .utils import normalize_img, feature_add_position
|
|
|
|
|
|
class GMFlow(nn.Module):
    """GMFlow optical-flow estimator.

    Pipeline: shared CNN feature extraction for both images, Transformer-based
    cross-view feature enhancement, correlation-softmax matching to predict
    flow, self-attention flow propagation, and learned convex upsampling of
    the final prediction.
    """

    def __init__(self,
                 num_scales=1,
                 upsample_factor=8,
                 feature_channels=128,
                 attention_type='swin',
                 num_transformer_layers=6,
                 ffn_dim_expansion=4,
                 num_head=1,
                 **kwargs,
                 ):
        """Build the GMFlow sub-modules.

        Args:
            num_scales: number of feature scales produced by the backbone and
                iterated over in ``forward`` (coarse-to-fine).
            upsample_factor: factor used by the learned convex upsampler at the
                finest scale (also sizes the upsampler's output channels).
            feature_channels: channel width of the backbone/Transformer features.
            attention_type: attention variant passed to ``FeatureTransformer``.
            num_transformer_layers: depth of the feature Transformer.
            ffn_dim_expansion: FFN hidden-dim expansion inside the Transformer.
            num_head: number of attention heads.
            **kwargs: ignored; kept for config-dict compatibility.
        """
        super().__init__()  # modern zero-arg form of super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone; produces one feature map per scale.
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer for cross-view feature enhancement.
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              attention_type=attention_type,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

        # Self-attention module that propagates flow using feature similarity.
        self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)

        # Convex upsampling head: predicts softmax weights over a 3x3
        # neighborhood for each of the upsample_factor**2 sub-pixels.
        self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))
|
|
def extract_feature(self, img0, img1):
|
|
concat = torch.cat((img0, img1), dim=0)
|
|
features = self.backbone(concat)
|
|
|
|
|
|
features = features[::-1]
|
|
|
|
feature0, feature1 = [], []
|
|
|
|
for i in range(len(features)):
|
|
feature = features[i]
|
|
chunks = torch.chunk(feature, 2, 0)
|
|
feature0.append(chunks[0])
|
|
feature1.append(chunks[1])
|
|
|
|
return feature0, feature1
|
|
|
|
def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
|
|
):
|
|
if bilinear:
|
|
up_flow = F.interpolate(flow, scale_factor=upsample_factor,
|
|
mode='bilinear', align_corners=True) * upsample_factor
|
|
|
|
else:
|
|
|
|
concat = torch.cat((flow, feature), dim=1)
|
|
|
|
mask = self.upsampler(concat)
|
|
b, flow_channel, h, w = flow.shape
|
|
mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w)
|
|
mask = torch.softmax(mask, dim=2)
|
|
|
|
up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
|
|
up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)
|
|
|
|
up_flow = torch.sum(mask * up_flow, dim=2)
|
|
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
|
up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h,
|
|
self.upsample_factor * w)
|
|
|
|
return up_flow
|
|
|
|
def forward(self, img0, img1,
|
|
attn_splits_list=None,
|
|
corr_radius_list=None,
|
|
prop_radius_list=None,
|
|
pred_bidir_flow=False,
|
|
**kwargs,
|
|
):
|
|
|
|
results_dict = {}
|
|
flow_preds = []
|
|
|
|
img0, img1 = normalize_img(img0, img1)
|
|
|
|
|
|
feature0_list, feature1_list = self.extract_feature(img0, img1)
|
|
|
|
flow = None
|
|
|
|
assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
|
|
|
|
for scale_idx in range(self.num_scales):
|
|
feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
|
|
|
|
if pred_bidir_flow and scale_idx > 0:
|
|
|
|
feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
|
|
|
|
upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))
|
|
|
|
if scale_idx > 0:
|
|
flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
|
|
|
|
if flow is not None:
|
|
flow = flow.detach()
|
|
feature1 = flow_warp(feature1, flow)
|
|
|
|
attn_splits = attn_splits_list[scale_idx]
|
|
corr_radius = corr_radius_list[scale_idx]
|
|
prop_radius = prop_radius_list[scale_idx]
|
|
|
|
|
|
feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
|
|
|
|
|
|
feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)
|
|
|
|
|
|
if corr_radius == -1:
|
|
flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
|
|
else:
|
|
flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
|
|
|
|
|
|
flow = flow + flow_pred if flow is not None else flow_pred
|
|
|
|
|
|
if self.training:
|
|
flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor)
|
|
flow_preds.append(flow_bilinear)
|
|
|
|
|
|
if pred_bidir_flow and scale_idx == 0:
|
|
feature0 = torch.cat((feature0, feature1), dim=0)
|
|
flow = self.feature_flow_attn(feature0, flow.detach(),
|
|
local_window_attn=prop_radius > 0,
|
|
local_window_radius=prop_radius)
|
|
|
|
|
|
if self.training and scale_idx < self.num_scales - 1:
|
|
flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor)
|
|
flow_preds.append(flow_up)
|
|
|
|
if scale_idx == self.num_scales - 1:
|
|
flow_up = self.upsample_flow(flow, feature0)
|
|
flow_preds.append(flow_up)
|
|
|
|
results_dict.update({'flow_preds': flow_preds})
|
|
|
|
return results_dict
|
|
|