Realcat's picture
add: rdd sparse and dense match
1b369eb
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils.misc import NestedTensor, nested_tensor_from_tensor_list
import torchvision.transforms as transforms
from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer
class BasicLayer(nn.Module):
    """Convolution block: Conv2d, then BatchNorm2d (without affine terms), then ReLU.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding passed to the convolution.
        dilation: Convolution dilation.
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        norm = nn.BatchNorm2d(out_channels, affine=False)
        activation = nn.ReLU(inplace=False)
        # Stored as a single Sequential named ``layer`` so parameter and
        # buffer names remain ``layer.<index>.*``.
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the conv/norm/activation stack and return the result."""
        return self.layer(x)
class RDD_Descriptor(nn.Module):
    """Multi-scale descriptor network: backbone -> per-level projections -> transformer.

    The forward pass returns a fused dense feature map together with a
    single-channel map in [0, 1] produced by the sigmoid-terminated
    ``matchibility_head``.

    Args:
        backbone: Feature extractor. Must expose ``strides`` and
            ``num_channels`` sequences, return ``(features, pos)`` when
            called, and provide its positional-encoding module at index 1.
        transformer: Module exposing ``d_model`` and returning
            ``(flatten_feats, spatial_shapes, level_start_index)`` when called
            with ``(srcs, masks, pos)``.
        num_feature_levels: Number of feature levels handed to the
            transformer. Levels beyond those the backbone provides are
            created with stride-2 convolutions.
    """

    def __init__(self, backbone, transformer, num_feature_levels):
        super().__init__()
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels

        # The head consumes the fused feature map, which has ``hidden_dim``
        # channels, so its input width follows ``hidden_dim`` rather than a
        # hard-coded 256 (identical when d_model == 256).
        self.matchibility_head = nn.Sequential(
            BasicLayer(self.hidden_dim, 128, 1, padding=0),
            BasicLayer(128, 64, 1, padding=0),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid(),
        )

        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            # One 1x1 projection per backbone output level.
            for level in range(num_backbone_outs):
                in_channels = backbone.num_channels[level]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ))
            # Extra levels: the first one reads the last backbone level (so it
            # keeps that level's channel count); every later one reads the
            # previous extra level, which already has ``hidden_dim`` channels.
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ))
                in_channels = self.hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], self.hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ),
            ])
        self.backbone = backbone
        self.stride = backbone.strides[0]

        # Re-initialise only the convolution of each projection.
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def forward(self, samples: NestedTensor):
        """Compute the fused descriptor map and the matchability map.

        Args:
            samples: A ``NestedTensor``; any other input is converted with
                ``nested_tensor_from_tensor_list``.

        Returns:
            Tuple ``(final_feature, matchibility)``: the sum of all level
            feature maps, each resized to the first level's resolution, and
            the output of ``matchibility_head`` applied to that sum.
        """
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        # Project every backbone level to ``hidden_dim`` channels.
        srcs = []
        masks = []
        for level, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[level](src))
            masks.append(mask)
            assert mask is not None

        # Build the levels the backbone does not provide. The range is empty
        # when the backbone already supplies enough levels.
        num_backbone_levels = len(srcs)
        for level in range(num_backbone_levels, self.num_feature_levels):
            if level == num_backbone_levels:
                src = self.input_proj[level](features[-1].tensors)
            else:
                src = self.input_proj[level](srcs[-1])
            # Resize the input mask to the new level's resolution.
            mask = F.interpolate(samples.mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
            pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
            srcs.append(src)
            masks.append(mask)
            pos.append(pos_l)

        flatten_feats, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos)

        # Reshape the flattened features back to the original spatial shapes.
        # Level i occupies tokens [bounds[i], bounds[i + 1]); the closing bound
        # is exactly the token count. Bounds are converted to Python ints once
        # instead of slicing with 0-dim tensors for every level.
        bounds = level_start_index.tolist()
        bounds.append(flatten_feats.shape[1])
        feats = []
        for shape, start, end in zip(spatial_shapes, bounds[:-1], bounds[1:]):
            assert len(shape) == 2
            height, width = (int(dim) for dim in shape)
            level_tokens = flatten_feats[:, start:end, :]
            feats.append(level_tokens.transpose(1, 2).view(-1, self.hidden_dim, height, width))

        # Sum up the features from different levels at the first level's resolution.
        final_feature = feats[0]
        for feat in feats[1:]:
            final_feature = final_feature + F.interpolate(
                feat, size=final_feature.shape[-2:], mode='bilinear', align_corners=True
            )

        matchibility = self.matchibility_head(final_feature)
        return final_feature, matchibility
def build_descriptor(config):
    """Assemble an ``RDD_Descriptor`` from a configuration mapping.

    The backbone is built first, then the deformable transformer, both from
    ``config``; the number of feature levels is read from
    ``config['num_feature_levels']``.
    """
    feature_extractor = build_backbone(config)
    deformable_transformer = build_deforamble_transformer(config)
    num_levels = config['num_feature_levels']
    return RDD_Descriptor(
        backbone=feature_extractor,
        transformer=deformable_transformer,
        num_feature_levels=num_levels,
    )