import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils.misc import NestedTensor, nested_tensor_from_tensor_list
import torchvision.transforms as transforms
from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer


class BasicLayer(nn.Module):
    """Basic convolutional layer: Conv2d -> BatchNorm2d -> ReLU.

    The batch norm is created with ``affine=False`` (no learnable scale or
    shift) and the ReLU is not in-place.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, bias=False):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                stride=stride,
                dilation=dilation,
                bias=bias,
            ),
            nn.BatchNorm2d(out_channels, affine=False),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        return self.layer(x)


class RDD_Descriptor(nn.Module):
    """Dense descriptor network: backbone -> input projections -> transformer.

    The per-level transformer outputs are reshaped back to 2-D maps, resized
    to the resolution of the first level and summed into a single descriptor
    map. A small 1x1-conv head then predicts a sigmoid "matchability" score
    for every location of that map.

    Args:
        backbone: Module providing ``strides`` and ``num_channels`` sequences.
            Calling it on a ``NestedTensor`` must return ``(features, pos)``,
            and ``backbone[1]`` must be the positional-encoding module.
        transformer: Module exposing ``d_model``. Calling it with
            ``(srcs, masks, pos)`` must return
            ``(flatten_feats, spatial_shapes, level_start_index)``.
        num_feature_levels: Number of feature levels fed to the transformer.
            Levels beyond those produced by the backbone are synthesised with
            stride-2 3x3 convolutions.
    """

    def __init__(self, backbone, transformer, num_feature_levels):
        super().__init__()
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels

        # The head consumes the fused descriptor map, which has
        # ``hidden_dim`` channels, so its input width follows the transformer
        # width rather than a fixed 256.
        self.matchibility_head = nn.Sequential(
            BasicLayer(self.hidden_dim, 128, 1, padding=0),
            BasicLayer(128, 64, 1, padding=0),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid(),
        )

        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            # One 1x1 projection per backbone output level.
            for level in range(num_backbone_outs):
                in_channels = backbone.num_channels[level]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ))
            # Extra levels: the first one reads the last backbone feature
            # (``in_channels`` still holds its width), later ones read the
            # previous projected level (``hidden_dim`` channels).
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, self.hidden_dim, kernel_size=3,
                              stride=2, padding=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ))
                in_channels = self.hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], self.hidden_dim,
                              kernel_size=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ),
            ])

        self.backbone = backbone
        self.stride = backbone.strides[0]

        # Re-initialise only the convolution of every projection.
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def forward(self, samples: NestedTensor):
        """Compute the dense descriptor map and its matchability map.

        Args:
            samples: A ``NestedTensor``; any other input is converted with
                ``nested_tensor_from_tensor_list``.

        Returns:
            Tuple ``(final_feature, matchibility)``: the fused descriptor map
            with ``hidden_dim`` channels at the resolution of the first
            feature level, and the single-channel sigmoid output of the
            matchability head at the same resolution.
        """
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        # Project every backbone level to the transformer width.
        srcs = []
        masks = []
        for level, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[level](src))
            masks.append(mask)
            assert mask is not None

        # Synthesise the additional, coarser levels if more are requested
        # than the backbone provides.
        if self.num_feature_levels > len(srcs):
            num_backbone_levels = len(srcs)
            for level in range(num_backbone_levels, self.num_feature_levels):
                if level == num_backbone_levels:
                    src = self.input_proj[level](features[-1].tensors)
                else:
                    src = self.input_proj[level](srcs[-1])
                # Resize the input padding mask to this level's resolution.
                mask = F.interpolate(
                    samples.mask[None].float(), size=src.shape[-2:]
                ).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        flatten_feats, spatial_shapes, level_start_index = self.transformer(
            srcs, masks, pos)

        # Reshape the flattened features back to the original spatial shapes.
        # Append the total token count as the end offset of the last level so
        # that level ``i`` occupies [start[i], start[i + 1]).
        total_tokens = flatten_feats.shape[1]
        level_start_index = torch.cat((
            level_start_index,
            torch.tensor([total_tokens], device=level_start_index.device),
        ))
        feats = []
        for i, shape in enumerate(spatial_shapes):
            assert len(shape) == 2
            tokens = flatten_feats[
                :, level_start_index[i]:level_start_index[i + 1], :]
            feats.append(
                tokens.transpose(1, 2).view(-1, self.hidden_dim, *shape))

        # Sum up the features from different levels at the resolution of the
        # first level.
        final_feature = feats[0]
        for feat in feats[1:]:
            final_feature = final_feature + F.interpolate(
                feat,
                size=final_feature.shape[-2:],
                mode='bilinear',
                align_corners=True,
            )

        matchibility = self.matchibility_head(final_feature)

        return final_feature, matchibility


def build_descriptor(config):
    """Build an ``RDD_Descriptor`` from ``config``.

    ``config`` is passed to ``build_backbone`` and
    ``build_deforamble_transformer`` and must provide the key
    ``'num_feature_levels'``.
    """
    backbone = build_backbone(config)
    transformer = build_deforamble_transformer(config)
    return RDD_Descriptor(backbone, transformer, config['num_feature_levels'])