"""This code is refer from: | |
https://github.com/MelosY/CAM | |
""" | |
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.init import trunc_normal_ | |
from .convnextv2 import ConvNeXtV2, Block, LayerNorm | |
from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33 | |
class Swish(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
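
# Note: Swish(x) = x * sigmoid(x) is the same function as SiLU, so on
# torch >= 1.7 this class is interchangeable with nn.SiLU().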

class UNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        A UNet-style decoder block: a transposed convolution up-samples by
        `stride` (1x, 2x, or 4x per axis), followed by two 3x3 convolutions.
        `deformable` is accepted for interface compatibility but unused.
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
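        # Sanity check of the (kernel, padding) choices above, using the
        # ConvTranspose2d size formula out = (in - 1) * stride - 2 * pad + k:
        #   stride 1: (in - 1) * 1 - 0 + 1 = in      (identity)
        #   stride 2: (in - 1) * 2 - 2 + 4 = 2 * in  (exact 2x)
        #   stride 4: (in - 1) * 4 - 0 + 4 = 4 * in  (exact 4x)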
        conv = nn.Conv2d
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            conv(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            conv(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)

class DepthWiseUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        Same as `UNetBlock`, but the two 3x3 convolutions are replaced by
        depthwise-separable pairs (3x3 depthwise + 1x1 pointwise).
        `deformable` is accepted for interface compatibility but unused.
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False),
            bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)
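
# A depthwise-separable pair (3x3 depthwise + 1x1 pointwise) costs roughly
# 9 * C + C * C weights versus 9 * C * C for a dense 3x3 convolution, i.e.
# close to a 9x parameter/FLOP reduction for large C, which is why the
# depthwise variant is the default in CAMEncoder (use_depthwise_unet=True).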

class SFTLayer(nn.Module):

    def __init__(self, dim_in, dim_out):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_scale_conv1 = nn.Linear(dim_in, dim_out)
        self.SFT_shift_conv0 = nn.Linear(dim_in, dim_in)
        self.SFT_shift_conv1 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        # x[0]: feature to modulate; x[1]: conditioning feature
        scale = self.SFT_scale_conv1(
            F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(
            F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
        return x[0] * (scale + 1) + shift
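
# Spatial Feature Transform (SFT; Wang et al., "Recovering Realistic Texture
# in Image Super-Resolution by Deep Spatial Feature Transform"): the output is
# fea * (1 + scale(cond)) + shift(cond), an affine modulation of the feature
# conditioned on a second input. This class is defined here but not referenced
# by the modules below.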

class MoreUNetBlock(nn.Module):

    def __init__(self, cin, cout, bn2d, stride, deformable=False):
        """
        Same as `DepthWiseUNetBlock`, but with four depthwise-separable
        convolution pairs after the up-sampling instead of two.
        `deformable` is accepted for interface compatibility but unused.
        """
        super().__init__()
        stride_h, stride_w = stride
        if stride_h == 1:
            kernel_h = 1
            padding_h = 0
        elif stride_h == 2:
            kernel_h = 4
            padding_h = 1
        elif stride_h == 4:
            kernel_h = 4
            padding_h = 0
        if stride_w == 1:
            kernel_w = 1
            padding_w = 0
        elif stride_w == 2:
            kernel_w = 4
            padding_w = 1
        elif stride_w == 4:
            kernel_w = 4
            padding_w = 0
        self.up_sample = nn.ConvTranspose2d(cin,
                                            cin,
                                            kernel_size=(kernel_h, kernel_w),
                                            stride=(stride_h, stride_w),
                                            padding=(padding_h, padding_w),
                                            bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False), bn2d(cin), nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0,
                      bias=False), bn2d(cin), nn.ReLU6(inplace=True),
            nn.Conv2d(cin,
                      cin,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cin),
            nn.Conv2d(cin,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False), bn2d(cout), nn.ReLU6(inplace=True),
            nn.Conv2d(cout,
                      cout,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False,
                      groups=cout),
            nn.Conv2d(cout,
                      cout,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=False), bn2d(cout))

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)

class BinaryDecoder(nn.Module):

    def __init__(self,
                 dim,
                 num_classes,
                 strides,
                 use_depthwise_unet=False,
                 use_more_unet=False,
                 binary_loss_type='DiceLoss') -> None:
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        self.linear_enc2binary = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.SyncBatchNorm(dim),
        )
        self.strides = strides
        self.use_deformable = False
        self.binary_decoder = nn.ModuleList()
        unet = DepthWiseUNetBlock if use_depthwise_unet else UNetBlock
        unet = MoreUNetBlock if use_more_unet else unet
        for i in range(3):
            up_sample_stride = self.strides[::-1][i]
            cin, cout = channels[i], channels[i + 1]
            self.binary_decoder.append(
                unet(cin, cout, nn.SyncBatchNorm, up_sample_stride,
                     self.use_deformable))
        last_stride = (self.strides[0][0] // 2, self.strides[0][1] // 2)
        self.binary_decoder.append(
            unet(cout, cout, nn.SyncBatchNorm, last_stride,
                 self.use_deformable))
        if binary_loss_type in ('CrossEntropyDiceLoss',
                                'BanlanceMultiClassCrossEntropyLoss'):
            segm_num_cls = num_classes - 2
        else:
            segm_num_cls = num_classes - 3
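        # The segmentation head predicts fewer classes than the recognizer;
        # the offset (-2 vs. -3) matches how many special symbols the chosen
        # binary loss expects dropped, as defined by the loss elsewhere.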
        self.binary_pred = nn.Conv2d(channels[-1],
                                     segm_num_cls,
                                     kernel_size=1,
                                     stride=1,
                                     bias=True)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, p_h * p_w * 3) with L = (H // p_h) * (W // p_w)
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        h = imgs.shape[2] // p_h
        w = imgs.shape[3] // p_w
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p_h, w, p_w))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p_h * p_w * 3))
        return x
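
    # Worked example for patchify (hypothetical shapes): with
    # strides[0] == (4, 4) the patch is p_h = p_w = 2, so an (N, 3, 32, 128)
    # input yields h, w = 16, 64 and x of shape (N, 1024, 12).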
    def unpatchify(self, x):
        """
        x: (N, p_h * p_w, h, w)
        imgs: (N, H, W) with H = h * p_h, W = w * p_w (single channel)
        """
        p_h, p_w = self.strides[0]
        p_h = p_h // 2
        p_w = p_w // 2
        _, _, h, w = x.shape
        assert p_h * p_w == x.shape[1]
        x = x.permute(0, 2, 3, 1)  # [N, h, w, p_h * p_w]
        x = x.reshape(shape=(x.shape[0], h, w, p_h, p_w))
        x = torch.einsum('nhwpq->nhpwq', x)
        imgs = x.reshape(shape=(x.shape[0], h * p_h, w * p_w))
        return imgs
    def forward(self, x, time=None):
        """
        x: the encoder feature used to initialize the binary prediction;
           `time` is accepted for interface compatibility but unused.
        Returns the binary logits (softmax probabilities at inference) and
        the list of intermediate decoder features.
        """
        binary_feats = []
        x = self.linear_enc2binary(x)
        binary_feats.append(x.clone())
        for i, d in enumerate(self.binary_decoder):
            x = d(x)
            binary_feats.append(x.clone())
        x = self.binary_pred(x)
        if self.training:
            return x, binary_feats
        else:
            return x.softmax(1), binary_feats

class LayerNormProxy(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        return x.permute(0, 3, 1, 2)
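
# LayerNormProxy applies LayerNorm over the channel dimension of an NCHW
# tensor by round-tripping through NHWC: per-position normalization across
# channels, not the spatial LayerNorm nn.LayerNorm([C, H, W]) would give.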

class DAttentionFuse(nn.Module):

    def __init__(
        self,
        q_size=(4, 32),
        kv_size=(4, 32),
        n_heads=8,
        n_head_channels=80,
        n_groups=4,
        attn_drop=0.0,
        proj_drop=0.0,
        stride=2,
        offset_range_factor=2,
        use_pe=True,
        stage_idx=0,
    ):
        """
        Deformable attention used to fuse two feature maps.
        `stage_idx` (2 or 3 in practice) selects the offset-network kernel
        size from [9, 7, 5, 3].
        """
        super().__init__()
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        self.kv_h, self.kv_w = kv_size
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.offset_range_factor = offset_range_factor
        ksizes = [9, 7, 5, 3]
        kk = ksizes[stage_idx]
        self.conv_offset = nn.Sequential(
            nn.Conv2d(2 * self.n_group_channels,
                      2 * self.n_group_channels,
                      kk,
                      stride,
                      kk // 2,
                      groups=self.n_group_channels),
            LayerNormProxy(2 * self.n_group_channels), nn.GELU(),
            nn.Conv2d(2 * self.n_group_channels, 2, 1, 1, 0, bias=False))
        self.proj_q = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_k = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_v = nn.Conv2d(self.nc,
                                self.nc,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.proj_out = nn.Conv2d(self.nc,
                                  self.nc,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)
        if self.use_pe:
            self.rpe_table = nn.Parameter(
                torch.zeros(self.n_heads, self.kv_h * 2 - 1,
                            self.kv_w * 2 - 1))
            trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype,
                           device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype,
                           device=device))
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1,
                                    -1)  # B * g H W 2
        return ref
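
    # _get_ref_points builds the base sampling grid in grid_sample's
    # normalized coordinates: pixel center i + 0.5 maps to
    # coord = 2 * (i + 0.5) / S - 1, landing inside (-1, 1) under the
    # align_corners=False convention used below.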
    def forward(self, x, y):
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device
        # rearrange(torch.cat((x, y), dim=1), 'b (g c) h w -> (b g) c h w',
        #           g=self.n_groups, c=2 * self.n_group_channels)
        q_off = torch.cat((x, y), dim=1).reshape(
            B, self.n_groups, 2 * self.n_group_channels, H, W).flatten(0, 1)
        offset = self.conv_offset(q_off)  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk
        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk],
                                        device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(
                self.offset_range_factor)
        offset = offset.permute(0, 2, 3, 1)  # 'b p h w -> b h w p'
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)
        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).tanh()
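        # With offset_range_factor > 0, tanh() bounds each raw offset to
        # +/- offset_range_factor * (1 / Hk, 1 / Wk) in normalized
        # coordinates; since one key-grid cell spans 2 / Hk, that is at most
        # offset_range_factor / 2 cells of displacement per axis. A negative
        # factor instead squashes the summed position itself through tanh.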
        q = self.proj_q(y)
        x_sampled = F.grid_sample(
            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
            grid=pos[..., (1, 0)],  # y, x -> x, y
            mode='bilinear',
            align_corners=False)  # B * g, Cg, Hg, Wg
        x_sampled = x_sampled.reshape(B, C, 1, n_sample)
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads,
                                           self.n_head_channels, n_sample)
        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)
        if self.use_pe:
            rpe_table = self.rpe_table
            rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
            q_grid = self._get_ref_points(H, W, B, dtype, device)
            displacement = (
                q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) -
                pos.reshape(B * self.n_groups, n_sample,
                            2).unsqueeze(1)).mul(0.5)
            attn_bias = F.grid_sample(input=rpe_bias.reshape(
                B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
                                      grid=displacement[..., (1, 0)],
                                      mode='bilinear',
                                      align_corners=False)
            attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            attn = attn + attn_bias
        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)
        out = torch.einsum('b m n, b c n -> b c m', attn, v)
        out = out.reshape(B, C, H, W)
        out = self.proj_drop(self.proj_out(out))
        return out, pos.reshape(B, self.n_groups, Hk, Wk,
                                2), reference.reshape(B, self.n_groups, Hk,
                                                      Wk, 2)
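
# DAttentionFuse.forward returns the fused map (B, C, H, W) plus the deformed
# sampling positions and the undeformed reference grid, each shaped
# (B, n_groups, Hk, Wk, 2); FuseModel below only consumes the fused map.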

class FuseModel(nn.Module):

    def __init__(self,
                 dim,
                 deform_stride=2,
                 stage_idx=2,
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=(2, 32)):
        super().__init__()
        channels = [dim // 2**i for i in range(4)]
        refine_conv = nn.Conv2d
        self.deform_stride = deform_stride
        in_out_ch = [(-1, -2), (-2, -3), (-3, -4), (-4, -4)]
        self.binary_condition_layer = DAttentionFuse(q_size=q_size,
                                                     kv_size=q_size,
                                                     stride=self.deform_stride,
                                                     n_head_channels=dim // 8,
                                                     stage_idx=stage_idx)
        self.binary2refine_linear_norm = nn.ModuleList()
        for i in range(len(k_size)):
            self.binary2refine_linear_norm.append(
                nn.Sequential(
                    Block(dim=channels[in_out_ch[i][0]]),
                    LayerNorm(channels[in_out_ch[i][0]],
                              eps=1e-6,
                              data_format='channels_first'),
                    refine_conv(channels[in_out_ch[i][0]],
                                channels[in_out_ch[i][1]],
                                kernel_size=k_size[i],
                                stride=k_size[i])))

    def forward(self, recog_feat, binary_feats, dec_in=None):
        multi_feat = []
        binary_feat = binary_feats[-1]
        for i in range(len(self.binary2refine_linear_norm)):
            binary_feat = self.binary2refine_linear_norm[i](binary_feat)
            multi_feat.append(binary_feat)
        binary_feat = binary_feat + binary_feats[0]
        multi_feat[3] += binary_feats[0]
        binary_refined_feat, pos, _ = self.binary_condition_layer(
            recog_feat, binary_feat)
        binary_refined_feat = binary_refined_feat + binary_feat
        return binary_refined_feat, binary_feat

class CAMEncoder(nn.Module):
    """
    Args:
        in_channels (int): Number of input image channels. Default: 3
        encoder_config (dict): Backbone config; `name` selects the class
            (e.g. 'ConvNeXtV2' or 'SVTRv2LNConvTwo33') and the remaining
            keys are forwarded to its constructor.
        nb_classes (int): Number of recognition classes. Default: 71
        strides (list(tuple)): Per-stage down-sampling strides of the
            backbone, mirrored by the binary decoder's up-sampling.
        use_depthwise_unet / use_more_unet (bool): Select the UNet block
            variant used in the binary decoder.
        mid_size (bool): If True, halve the backbone channels before the
            decoders.
        d_embedding (int): Output embedding dimension. Default: 384
    """

    def __init__(self,
                 in_channels=3,
                 encoder_config={'name': 'ConvNeXtV2'},
                 nb_classes=71,
                 strides=[(4, 4), (2, 1), (2, 1), (1, 1)],
                 k_size=[(2, 2), (2, 1), (2, 1), (1, 1)],
                 q_size=[2, 32],
                 deform_stride=2,
                 stage_idx=2,
                 use_depthwise_unet=True,
                 use_more_unet=False,
                 binary_loss_type='BanlanceMultiClassCrossEntropyLoss',
                 mid_size=True,
                 d_embedding=384):
        super().__init__()
        encoder_name = encoder_config.pop('name')
        encoder_config['in_channels'] = in_channels
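        # The backbone class is resolved by name via eval() from the imports
        # at the top (ConvNeXtV2 or SVTRv2LNConvTwo33); note this also
        # mutates the caller's encoder_config dict via pop().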
        self.backbone = eval(encoder_name)(**encoder_config)
        dim = self.backbone.out_channels
        self.mid_size = mid_size
        if self.mid_size:
            self.enc_downsample = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim // 2),
                nn.Conv2d(dim // 2,
                          dim // 2,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim // 2),
            )
            dim = dim // 2
            # recognition decoder
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False,
                          groups=dim),
                nn.Conv2d(dim,
                          dim,
                          kernel_size=1,
                          stride=1,
                          padding=0,
                          bias=False),
                nn.SyncBatchNorm(dim),
            )
        else:
            self.linear_enc2recog = nn.Sequential(
                nn.Conv2d(dim, dim // 2, kernel_size=1, stride=1),
                nn.SyncBatchNorm(dim // 2),
                # nn.ReLU6(inplace=True),
                nn.Conv2d(dim // 2, dim, kernel_size=3, stride=1, padding=1),
                nn.SyncBatchNorm(dim),
            )
        self.linear_norm = nn.Sequential(
            nn.Linear(dim, d_embedding),
            nn.LayerNorm(d_embedding, eps=1e-6),
        )
        self.out_channels = d_embedding
        self.binary_decoder = BinaryDecoder(
            dim,
            nb_classes,
            strides,
            use_depthwise_unet=use_depthwise_unet,
            use_more_unet=use_more_unet,
            binary_loss_type=binary_loss_type)
        self.fuse_model = FuseModel(dim,
                                    deform_stride=deform_stride,
                                    stage_idx=stage_idx,
                                    k_size=k_size,
                                    q_size=q_size)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, (nn.Conv2d, nn.Linear)) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.SyncBatchNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        return {}
    def forward(self, x):
        output = {}
        enc_feat = self.backbone(x)
        if self.mid_size:
            enc_feat = self.enc_downsample(enc_feat)
        output['enc_feat'] = enc_feat
        # binary mask
        pred_binary, binary_feats = self.binary_decoder(enc_feat)
        output['pred_binary'] = pred_binary
        reg_feat = self.linear_enc2recog(enc_feat)
        B, C, H, W = reg_feat.shape
        last_feat, binary_feat = self.fuse_model(reg_feat, binary_feats)
        dec_in = last_feat.reshape(B, C, H * W).permute(0, 2, 1)
        dec_in = self.linear_norm(dec_in)
        output['refined_feat'] = dec_in
        output['binary_feat'] = binary_feats[-1]
        return output
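
# Minimal usage sketch (assumed input shape; kept as a comment so nothing runs
# at import). nn.SyncBatchNorm needs an initialized distributed process group
# and GPU tensors during training, so a CPU smoke test would first need the
# norms swapped for nn.BatchNorm2d:
#
#   model = CAMEncoder(encoder_config={'name': 'ConvNeXtV2'}).eval()
#   out = model(torch.randn(2, 3, 32, 128))
#   # out['refined_feat'] -> (2, H * W, 384) tokens for the recognition head
#   # out['pred_binary']  -> per-pixel class probabilities at 1/2 input
#   #                        resolution (matching strides[0] // 2 in patchify)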