# Copyright (c) Facebook, Inc. and its affiliates.
import math

import fvcore.nn.weight_init as weight_init
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from torch import nn

from detectron2.layers import Conv2d, ShapeSpec, get_norm

from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .resnet import build_resnet_backbone
# Public API of this module.
__all__ = [
    "build_resnet_fpn_backbone",
    "build_retinanet_resnet_fpn_backbone",
    "FPN",
]
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> act -> Dropout -> Linear -> Dropout.

    Args:
        in_features (int): Size of each input vector.
        hidden_features (int, optional): Width of the hidden layer; falls back to
            ``in_features`` when not given (or falsy).
        out_features (int, optional): Size of each output vector; falls back to
            ``in_features`` when not given (or falsy).
        act_layer (type): Activation module class, instantiated with no arguments.
            Default: nn.GELU.
        drop (float): Dropout probability used after the activation and after the
            second linear layer. Default: 0.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        output_dim = out_features if out_features else in_features
        # Sub-modules are created in the order fc1, act, fc2, drop so that random
        # initialisation and state_dict key order are unchanged.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_partition(x, window_size):
    """Split a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C). H and W must be multiples of
            ``window_size``; otherwise the ``view`` below raises.
        window_size (int): side length of each window.

    Returns:
        Tensor of shape (num_windows*B, window_size, window_size, C). Windows are
        ordered batch first, then row-major over the window grid.
    """
    batch, height, width, channels = x.shape
    ws = window_size
    grid = x.view(batch, height // ws, ws, width // ws, ws, channels)
    # Move the two window-grid axes together: (B, H/ws, W/ws, ws, ws, C).
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, ws, ws, channels)
def window_reverse(windows, window_size, H, W):
    """Inverse of ``window_partition``: stitch windows back into a feature map.

    Args:
        windows: tensor of shape (num_windows*B, window_size, window_size, C), in
            the order produced by ``window_partition``.
        window_size (int): side length of each window.
        H (int): height of the reconstructed map (a multiple of ``window_size``).
        W (int): width of the reconstructed map (a multiple of ``window_size``).

    Returns:
        Tensor of shape (B, H, W, C).
    """
    rows = H // window_size
    cols = W // window_size
    # The batch size is an integer count, so compute it with integer arithmetic
    # rather than the previous float round-trip (n / (H * W / ws / ws)).
    batch = windows.shape[0] // (rows * cols)
    x = windows.view(batch, rows, cols, window_size, window_size, -1)
    # (B, rows, ws, cols, ws, C) -> row pixels of each window become contiguous.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch, H, W, -1)
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to a local window (W-MSA / SW-MSA).

    A learnable bias, selected by the relative (row, col) offset between each pair
    of tokens in the window, is added to the attention logits. The same module
    serves regular and shifted windows; shifted windows pass an additive mask.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): Window height and width (Wh, Ww).
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Used instead of head_dim ** -0.5 when truthy.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5
        win_h = self.window_size[0]
        win_w = self.window_size[1]

        # One learnable value per head for each of the (2*Wh-1)*(2*Ww-1) possible offsets.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))

        # Offset between every ordered pair of tokens in the window.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        # Shift both offset components to be non-negative, then fold (dy, dx)
        # into a single row-major index into the bias table.
        index = (offsets[0] + win_h - 1) * (2 * win_w - 1) + (offsets[1] + win_w - 1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Initialised after the linear layers so the RNG is consumed in the same order.
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Attend among the tokens of each window.

        Args:
            x: tensor of shape (num_windows*B, N, C), with N == Wh*Ww.
            mask: (0/-inf) additive mask of shape (num_windows, N, N), or None.

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        total_windows, n_tokens, channels = x.shape
        heads = self.num_heads

        qkv = self.qkv(x).reshape(total_windows, n_tokens, 3, heads, channels // heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # 3, B_, nH, N, head_dim
        # Index rather than unpack the tensor (TorchScript cannot unpack a tensor).
        query, key, value = qkv[0], qkv[1], qkv[2]
        logits = (query * self.scale) @ key.transpose(-2, -1)

        area = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(area, area, -1).permute(2, 0, 1).contiguous()  # nH, N, N
        logits = logits + bias.unsqueeze(0)

        if mask is not None:
            n_windows = mask.shape[0]
            logits = logits.view(total_windows // n_windows, n_windows, heads, n_tokens, n_tokens) \
                + mask.unsqueeze(1).unsqueeze(0)
            logits = logits.view(-1, heads, n_tokens, n_tokens)

        weights = self.attn_drop(self.softmax(logits))
        out = (weights @ value).transpose(1, 2).reshape(total_windows, n_tokens, channels)
        return self.proj_drop(self.proj(out))
class SwinTransformerBlock(nn.Module):
    """One Swin Transformer block: (shifted-)window attention followed by an MLP,
    each wrapped in a pre-norm residual connection.

    The spatial size of the input is not passed to ``forward``; it is read from
    ``self.H`` and ``self.W``, which must be assigned before each call.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA; 0 disables the cyclic shift.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # Sub-modules are created in the order norm1, attn, drop_path, norm2, mlp so
        # random initialisation is reproducible.
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        # Spatial size of the current input; set externally before forward.
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.

        Args:
            x: input tokens of shape (B, H*W, C), where H, W are ``self.H``, ``self.W``.
            mask_matrix: additive attention mask, used only when ``shift_size > 0``.

        Returns:
            Tensor of shape (B, H*W, C).
        """
        batch, length, channels = x.shape
        height, width = self.H, self.W
        assert length == height * width, "input feature has wrong size"
        ws = self.window_size
        shift = self.shift_size

        residual = x
        x = self.norm1(x).view(batch, height, width, channels)

        # Pad on the right/bottom so both spatial dims are multiples of the window size.
        pad_right = (ws - width % ws) % ws
        pad_bottom = (ws - height % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
        padded_h, padded_w = x.shape[1], x.shape[2]

        if shift > 0:
            # Cyclic shift towards the top-left; the mask is expected to block
            # attention between tokens brought together by the wrap-around.
            x = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            attn_mask = None

        windows = window_partition(x, ws).view(-1, ws * ws, channels)
        windows = self.attn(windows, mask=attn_mask)
        x = window_reverse(windows.view(-1, ws, ws, channels), ws, padded_h, padded_w)

        if shift > 0:
            x = torch.roll(x, shifts=(shift, shift), dims=(1, 2))
        if pad_right > 0 or pad_bottom > 0:
            x = x[:, :height, :width, :].contiguous()

        # Residual around attention, then residual around the MLP.
        x = residual + self.drop_path(x.view(batch, height * width, channels))
        return x + self.drop_path(self.mlp(self.norm2(x)))
class swin_layer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Blocks alternate between regular windows (even index, shift 0) and shifted
    windows (odd index, shift ``window_size // 2``).

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate,
            either shared by all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Accept a per-block sequence as documented. Previously a tuple was not
        # recognised and was passed whole to every block.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_drop_path else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def _compute_attn_mask(self, H, W, device):
        """Build the additive SW-MSA mask of shape (nW, ws*ws, ws*ws).

        The map is padded up to multiples of the window size and split into three
        regions per axis. Token pairs from different regions get -100.0 and
        same-region pairs get 0.0.
        """
        ws = self.window_size
        Hp = ((H + ws - 1) // ws) * ws
        Wp = ((W + ws - 1) // ws) * ws
        img_mask = torch.zeros((1, Hp, Wp, 1), device=device)  # 1 Hp Wp 1
        # The same region boundaries apply to both axes.
        regions = (slice(0, -ws),
                   slice(-ws, -self.shift_size),
                   slice(-self.shift_size, None))
        cnt = 0
        for h in regions:
            for w in regions:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, ws)  # nW, ws, ws, 1
        mask_windows = mask_windows.view(-1, ws * ws)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            tuple (x, H, W, x_down, Wh, Ww). Without a downsample layer, x_down is
            x and (Wh, Ww) == (H, W). Otherwise (Wh, Ww) == ((H + 1) // 2, (W + 1) // 2).
        """
        # Only shifted blocks consume the mask, so skip building it when there are
        # none (e.g. depth == 1).
        if any(blk.shift_size > 0 for blk in self.blocks):
            attn_mask = self._compute_attn_mask(H, W, x.device)
        else:
            attn_mask = None

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                # Requires `import torch.utils.checkpoint as checkpoint` at file level.
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)

        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        return x, H, W, x, H, W
class FPN(Backbone):
    """
    This module implements :paper:`FPN`.
    It creates pyramid features built on top of some input feature maps.

    In this variant, each level's usual 3x3 output conv is replaced by a
    one-block Swin Transformer stage (``swin_layer``) applied to the flattened map.
    """

    _fuse_type: torch.jit.Final[str]

    def __init__(
        self, bottom_up, in_features, out_channels, norm="", top_block=None, fuse_type="sum"
    ):
        """
        Args:
            bottom_up (Backbone): module representing the bottom up subnetwork.
                Must be a subclass of :class:`Backbone`. The multi-scale feature
                maps generated by the bottom up network, and listed in `in_features`,
                are used to generate FPN levels.
            in_features (list[str]): names of the input feature maps coming
                from the backbone to which FPN is attached. For example, if the
                backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
                of these may be used; order must be from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            norm (str): the normalization used by the lateral 1x1 convs. The Swin
                output layers use their own (LayerNorm) normalization.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra FPN levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            fuse_type (str): types for fusing the top down features and the lateral
                ones. It can be "sum" (default), which sums up element-wise; or "avg",
                which takes the element-wise mean of the two.
        """
        super(FPN, self).__init__()
        assert isinstance(bottom_up, Backbone)
        assert in_features, in_features

        # Feature map strides and channels from the bottom up network (e.g. ResNet)
        input_shapes = bottom_up.output_shape()
        strides = [input_shapes[f].stride for f in in_features]
        in_channels_per_feature = [input_shapes[f].channels for f in in_features]

        _assert_strides_are_log2_contiguous(strides)
        # Channel count of every output level; also used to reshape Swin outputs.
        self.out_channels = out_channels
        lateral_convs = []
        output_convs = []

        use_bias = norm == ""
        for idx, in_channels in enumerate(in_channels_per_feature):
            lateral_norm = get_norm(norm, out_channels)
            lateral_conv = Conv2d(
                in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm
            )
            # One-block Swin stage in place of the standard 3x3 output conv.
            output_conv = swin_layer(dim=out_channels, depth=1, num_heads=2, window_size=7)
            weight_init.c2_xavier_fill(lateral_conv)
            stage = int(math.log2(strides[idx]))
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.top_block = top_block
        self.in_features = tuple(in_features)
        self.bottom_up = bottom_up
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        assert fuse_type in {"avg", "sum"}
        self._fuse_type = fuse_type

    @property
    def size_divisibility(self):
        # A property, as in detectron2's `Backbone` base class, so that reading
        # `backbone.size_divisibility` yields an int rather than a bound method.
        return self._size_divisibility

    def _apply_output_layer(self, output_layer, features):
        """Run a ``swin_layer`` on an NCHW map and return an NCHW map.

        ``swin_layer`` works on tokens of shape (B, H*W, C), so the map is
        flattened before the call and reshaped back afterwards.
        """
        height, width = features.shape[-2], features.shape[-1]
        tokens = features.flatten(2).transpose(1, 2)  # (B, H*W, C)
        out, out_h, out_w, _, _, _ = output_layer(tokens, height, width)
        return out.transpose(1, 2).view(-1, self.out_channels, out_h, out_w)

    def forward(self, x):
        """
        Args:
            x: input passed unchanged to ``self.bottom_up``. The bottom-up network
                returns a dict mapping feature map name (e.g., "res5") to a feature
                map tensor.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to FPN feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                paper convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.bottom_up(x)
        results = []
        prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]])
        prev_features = self._apply_output_layer(self.output_convs[0], prev_features)
        results.append(prev_features)

        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs, self.output_convs)
        ):
            # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336
            # Therefore we loop over all modules but skip the first one
            if idx > 0:
                features = bottom_up_features[self.in_features[-idx - 1]]
                top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest")
                lateral_features = lateral_conv(features)
                prev_features = lateral_features + top_down_features
                if self._fuse_type == "avg":
                    prev_features /= 2
                # Each level uses its own output layer. Previously every level reused
                # output_convs[0], leaving the other registered layers unused.
                # The Swin output (not the raw fused map) feeds the next, finer level.
                prev_features = self._apply_output_layer(output_conv, prev_features)
                results.insert(0, prev_features)

        if self.top_block is not None:
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
def _assert_strides_are_log2_contiguous(strides):
    """
    Check that every stride is exactly twice the stride before it; raise
    AssertionError naming the first offending pair otherwise.
    """
    for previous, current in zip(strides, strides[1:]):
        assert current == 2 * previous, "Strides {} {} are not log2 contiguous".format(
            current, previous
        )
class LastLevelMaxPool(nn.Module):
    """
    Top block of the original FPN: derives one extra, coarser level (P6) from
    P5 by stride-2 subsampling.
    """

    def __init__(self):
        super().__init__()
        self.num_levels = 1  # number of extra pyramid levels this block adds
        self.in_feature = "p5"  # name of the feature map this block consumes

    def forward(self, x):
        # A 1x1 max-pool with stride 2 keeps every other row and column.
        subsampled = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
        return [subsampled]
class LastLevelP6P7(nn.Module):
    """
    Top block used by RetinaNet: adds two extra levels, P6 and P7, computed
    from the feature named ``in_feature`` ("res5", i.e. C5, by default) with
    stride-2 3x3 convolutions.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        super().__init__()
        self.num_levels = 2
        self.in_feature = in_feature
        self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        # Re-initialise both convs (p6 first, then p7).
        weight_init.c2_xavier_fill(self.p6)
        weight_init.c2_xavier_fill(self.p7)

    def forward(self, c5):
        p6 = self.p6(c5)
        # P7 is computed from ReLU(P6), while P6 itself is returned un-activated.
        return [p6, self.p7(F.relu(p6))]
@BACKBONE_REGISTRY.register()
def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet bottom-up network with this (Swin-output) FPN on top and a
    max-pool top block.

    Registered in ``BACKBONE_REGISTRY`` so the builder can be looked up by name.
    The registry was imported here but never used, so without the decorator this
    builder could not be found.

    Args:
        cfg: a detectron2 CfgNode
        input_shape (ShapeSpec): shape of the network input, forwarded to
            ``build_resnet_backbone``.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_resnet_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=LastLevelMaxPool(),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone
@BACKBONE_REGISTRY.register()
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet bottom-up network with this (Swin-output) FPN on top and the
    RetinaNet P6/P7 top block, which reads the bottom-up "res5" feature.

    Registered in ``BACKBONE_REGISTRY`` so the builder can be looked up by name.
    The registry was imported here but never used, so without the decorator this
    builder could not be found.

    Args:
        cfg: a detectron2 CfgNode
        input_shape (ShapeSpec): shape of the network input, forwarded to
            ``build_resnet_backbone``.

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_resnet_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    # The P6/P7 block consumes res5 directly, so its input width is res5's.
    in_channels_p6p7 = bottom_up.output_shape()["res5"].channels
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        top_block=LastLevelP6P7(in_channels_p6p7, out_channels),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone