Spaces:
Running
on
Zero
Running
on
Zero
from collections import OrderedDict | |
import torch | |
from torch.functional import norm | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision.models import vgg16, vgg16_bn | |
import fvcore.nn.weight_init as weight_init | |
from torchvision.models import resnet50 | |
from models.modules import ResBlk, DSLayer, half_DSLayer, CoAttLayer, RefUnet, DBHead | |
from config import Config | |
class GCoNet(nn.Module):
    """Encoder-decoder saliency network with optional auxiliary training heads.

    The encoder is a five-stage backbone (``conv1`` .. ``conv5``) taken from
    torchvision's ``vgg16``, ``vgg16_bn`` or ``resnet50``, selected by
    ``Config().bb``.  The decoder walks back up from ``conv5`` to the input
    resolution, adding a lateral projection of the matching encoder feature
    at every step and collecting one-channel predictions in ``scaled_preds``.

    Which optional sub-modules exist (``co_x5``, ``pred_layer``, ``avgpool`` /
    ``classifier``, ``refiner``, ``conv_out_mask``, ``db_mask``,
    ``db_output_decoder``, ``conv_cat_mask``, ``dslayer*``, ``conv_out2..4``)
    is decided entirely by flags read from ``Config()`` in ``__init__``.
    """

    def __init__(self, bb_pretrained=True):
        """Build the backbone, decoder and the config-selected auxiliary heads.

        Args:
            bb_pretrained: Forwarded as ``pretrained=`` to the torchvision
                backbone constructor.

        Raises:
            ValueError: If ``Config().bb`` is not one of ``'vgg16'``,
                ``'resnet50'`` or ``'vgg16bn'``.
        """
        super(GCoNet, self).__init__()
        self.config = Config()
        bb = self.config.bb
        if bb == 'vgg16':
            bb_net = list(vgg16(pretrained=bb_pretrained).children())[0]
            bb_convs = OrderedDict({
                'conv1': bb_net[:4],
                'conv2': bb_net[4:9],
                'conv3': bb_net[9:16],
                'conv4': bb_net[16:23],
                'conv5': bb_net[23:30]
            })
            channel_scale = 1
        elif bb == 'resnet50':
            bb_net = list(resnet50(pretrained=bb_pretrained).children())
            # Child index 3 is deliberately not part of any stage here:
            # conv1 takes children 0..2 and conv2 starts at child 4.
            bb_convs = OrderedDict({
                'conv1': nn.Sequential(*bb_net[0:3]),
                'conv2': bb_net[4],
                'conv3': bb_net[5],
                'conv4': bb_net[6],
                'conv5': bb_net[7]
            })
            channel_scale = 4
        elif bb == 'vgg16bn':
            bb_net = list(vgg16_bn(pretrained=bb_pretrained).children())[0]
            bb_convs = OrderedDict({
                'conv1': bb_net[:6],
                'conv2': bb_net[6:13],
                'conv3': bb_net[13:23],
                'conv4': bb_net[23:33],
                'conv5': bb_net[33:43]
            })
            channel_scale = 1
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `bb_convs`; fail early and say why.
            raise ValueError(
                "Unsupported backbone {!r}; expected 'vgg16', 'resnet50' "
                "or 'vgg16bn'.".format(bb)
            )
        self.bb = nn.Sequential(bb_convs)

        # Encoder channel counts ordered from the deepest stage (conv5) to conv1.
        if 'vgg16' in bb:
            lateral_channels_in = [512, 512, 256, 128, 64]
        else:
            lateral_channels_in = [2048, 1024, 512, 256, 64]

        conv_after_itp = self.config.conv_after_itp
        output_number = self.config.output_number

        # ----- decoder stage 5 (deepest) -----
        ch_decoder = lateral_channels_in[0] // 2 // channel_scale
        self.top_layer = ResBlk(lateral_channels_in[0], ch_decoder)
        self.enlayer5 = ResBlk(ch_decoder, ch_decoder)
        if conv_after_itp:
            self.dslayer5 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer5 = self._make_lateral(lateral_channels_in[1], ch_decoder)

        # ----- decoder stage 4 -----
        ch_decoder //= 2
        self.enlayer4 = ResBlk(ch_decoder * 2, ch_decoder)
        if conv_after_itp:
            self.dslayer4 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer4 = self._make_lateral(lateral_channels_in[2], ch_decoder)
        if output_number >= 4:
            self.conv_out4 = self._make_side_output(ch_decoder)

        # ----- decoder stage 3 -----
        ch_decoder //= 2
        self.enlayer3 = ResBlk(ch_decoder * 2, ch_decoder)
        if conv_after_itp:
            self.dslayer3 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer3 = self._make_lateral(lateral_channels_in[3], ch_decoder)
        if output_number >= 3:
            self.conv_out3 = self._make_side_output(ch_decoder)

        # ----- decoder stage 2 -----
        ch_decoder //= 2
        self.enlayer2 = ResBlk(ch_decoder * 2, ch_decoder)
        if conv_after_itp:
            self.dslayer2 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer2 = self._make_lateral(lateral_channels_in[4], ch_decoder)
        if output_number >= 2:
            self.conv_out2 = self._make_side_output(ch_decoder)

        # ----- decoder stage 1 (full resolution) -----
        self.enlayer1 = ResBlk(ch_decoder, ch_decoder)
        self.conv_out1 = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))

        # ----- optional heads; creation order is kept stable -----
        if self.config.GAM:
            self.co_x5 = CoAttLayer(channel_in=lateral_channels_in[0])

        if 'contrast' in self.config.loss:
            self.pred_layer = half_DSLayer(lateral_channels_in[0])

        if {'cls', 'cls_mask'} & set(self.config.loss):
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            # DUTS_class has 291 classes
            self.classifier = nn.Linear(lateral_channels_in[0], 291)
            weight_init.c2_msra_fill(self.classifier)
        if self.config.split_mask:
            self.sgm = nn.Sigmoid()
        if self.config.refine:
            self.refiner = nn.Sequential(RefUnet(self.config.refine, 64))
        if self.config.split_mask:
            self.conv_out_mask = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))
        if self.config.db_mask:
            self.db_mask = DBHead(32)
        if self.config.db_output_decoder:
            self.db_output_decoder = DBHead(32)
        if self.config.cls_mask_operation == 'c':
            # 4 -> 3 channels: presumably a 3-channel image concatenated with
            # a 1-channel mask -- TODO confirm against the data pipeline.
            self.conv_cat_mask = nn.Conv2d(4, 3, 1, 1, 0)

    def _make_lateral(self, in_channels, out_channels):
        """Return the lateral projection: a ResBlk or a plain 1x1 convolution."""
        if self.config.complex_lateral_connection:
            return ResBlk(in_channels, out_channels)
        return nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    @staticmethod
    def _make_side_output(in_channels):
        """Return the 1x1-conv head that maps a decoder feature to one channel."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 32, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, 1, 1, 0),
        )

    @staticmethod
    def _resize(feat, size):
        """Bilinearly resize ``feat`` to spatial ``size`` (align_corners=True)."""
        return F.interpolate(feat, size=size, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Run the encoder, decoder and (in training mode) the auxiliary heads.

        Args:
            x: 4-D input tensor; its size is unpacked as ``(N, _, H, W)``.

        Returns:
            In eval mode, the list ``scaled_preds`` of decoder predictions
            (side outputs enabled by ``config.output_number``, then the
            full-resolution output, then the refined output if
            ``config.refine`` is 1 or 4).

            In training mode, a list whose first element is ``scaled_preds``
            followed by the auxiliary outputs selected by the exact contents
            of ``config.loss`` (``pred_cls``, ``pred_contrast``,
            ``pred_cls_masks`` in that order), plus a trailing list of pooled
            features when ``config.lambdas_sal_last['triplet']`` is truthy.

        Raises:
            ValueError: In training mode with ``'cls_mask'`` in
                ``config.loss``, if the first-layer mask is requested while
                ``config.split_mask`` is false, or if
                ``config.cls_mask_operation`` is not ``'x'``, ``'+'`` or
                ``'c'``.

        NOTE(review): ``pred_contrast`` is only produced when ``config.GAM``
        is set, and the triplet feature ``_x5`` only when ``'cls'`` is in
        ``config.loss``; configurations that request them otherwise still
        fail with an unbound-name error, as before.
        """
        ########## Encoder ##########
        N, _, H, W = x.size()
        x1 = self.bb.conv1(x)
        x2 = self.bb.conv2(x1)
        x3 = self.bb.conv3(x2)
        x4 = self.bb.conv4(x3)
        x5 = self.bb.conv5(x4)

        # The classification head is only consumed by the training return
        # below, so it is skipped at inference time.
        if self.training and 'cls' in self.config.loss:
            _x5 = self.avgpool(x5)
            _x5 = _x5.view(_x5.size(0), -1)
            pred_cls = self.classifier(_x5)

        if self.config.GAM:
            weighted_x5, neg_x5 = self.co_x5(x5)
            if self.training and 'contrast' in self.config.loss:
                ########## contrastive branch #########
                cat_x5 = torch.cat([weighted_x5, neg_x5], dim=0)
                pred_contrast = self.pred_layer(cat_x5)
                pred_contrast = self._resize(pred_contrast, (H, W))
            p5 = self.top_layer(weighted_x5)
        else:
            p5 = self.top_layer(x5)

        ########## Decoder ##########
        conv_after_itp = self.config.conv_after_itp
        output_number = self.config.output_number
        scaled_preds = []

        p5 = self.enlayer5(p5)
        p5 = self._resize(p5, x4.shape[2:])
        if conv_after_itp:
            p5 = self.dslayer5(p5)
        p4 = p5 + self.latlayer5(x4)

        p4 = self.enlayer4(p4)
        p4 = self._resize(p4, x3.shape[2:])
        if conv_after_itp:
            p4 = self.dslayer4(p4)
        if output_number >= 4:
            scaled_preds.append(self.conv_out4(p4))
        p3 = p4 + self.latlayer4(x3)

        p3 = self.enlayer3(p3)
        p3 = self._resize(p3, x2.shape[2:])
        if conv_after_itp:
            p3 = self.dslayer3(p3)
        if output_number >= 3:
            scaled_preds.append(self.conv_out3(p3))
        p2 = p3 + self.latlayer3(x2)

        p2 = self.enlayer2(p2)
        p2 = self._resize(p2, x1.shape[2:])
        if conv_after_itp:
            p2 = self.dslayer2(p2)
        if output_number >= 2:
            scaled_preds.append(self.conv_out2(p2))
        p1 = p2 + self.latlayer2(x1)

        p1 = self.enlayer1(p1)
        p1 = self._resize(p1, x.shape[2:])
        if self.config.db_output_decoder:
            p1_out = self.db_output_decoder(p1)
        else:
            p1_out = self.conv_out1(p1)
        scaled_preds.append(p1_out)

        if self.config.refine == 1:
            scaled_preds.append(self.refiner(p1_out))
        elif self.config.refine == 4:
            scaled_preds.append(self.refiner(torch.cat([x, p1_out], dim=1)))

        if not self.training:
            return scaled_preds

        ########## Auxiliary training outputs ##########
        if 'cls_mask' in self.config.loss:
            pred_cls_masks = []
            norm_features_mask = []
            num_mask_layers = self.config.loss_cls_mask_last_layers
            input_features = [x, x1, x2, x3][:num_mask_layers]
            bb_lst = [self.bb.conv1, self.bb.conv2, self.bb.conv3, self.bb.conv4, self.bb.conv5]
            # A refined prediction, when present, sits at the end of
            # scaled_preds and must be skipped when indexing from the back.
            refine_offset = int(bool(self.config.refine))
            mask_op = self.config.cls_mask_operation
            for idx_out in range(num_mask_layers):
                if idx_out:
                    mask_output = scaled_preds[-(idx_out + 1 + refine_offset)]
                elif self.config.split_mask:
                    if self.config.db_mask:
                        mask_output = self.db_mask(p1)
                    else:
                        mask_output = self.sgm(self.conv_out_mask(p1))
                else:
                    # Previously an UnboundLocalError on `mask_output`.
                    raise ValueError(
                        "The 'cls_mask' loss needs config.split_mask to be "
                        "enabled to build the first-layer mask."
                    )

                if mask_op == 'x':
                    masked_features = input_features[idx_out] * mask_output
                elif mask_op == '+':
                    masked_features = input_features[idx_out] + mask_output
                elif mask_op == 'c':
                    masked_features = self.conv_cat_mask(
                        torch.cat((input_features[idx_out], mask_output), dim=1)
                    )
                else:
                    # Previously an UnboundLocalError on `masked_features`.
                    raise ValueError(
                        "Unsupported cls_mask_operation {!r}; expected 'x', "
                        "'+' or 'c'.".format(mask_op)
                    )

                # Push the masked features through the remaining backbone stages.
                feat = masked_features
                for stage in bb_lst[idx_out:]:
                    feat = stage(feat)
                norm_feature_mask = self.avgpool(feat).view(N, -1)
                norm_features_mask.append(norm_feature_mask)
                pred_cls_masks.append(self.classifier(norm_feature_mask))

        # The return layout depends on the exact set of configured losses;
        # any combination not listed here returns only the saliency outputs.
        loss_set = set(self.config.loss)
        if loss_set == {'sal', 'cls', 'contrast', 'cls_mask'}:
            return_values = [scaled_preds, pred_cls, pred_contrast, pred_cls_masks]
        elif loss_set == {'sal', 'cls', 'contrast'}:
            return_values = [scaled_preds, pred_cls, pred_contrast]
        elif loss_set == {'sal', 'cls', 'cls_mask'}:
            return_values = [scaled_preds, pred_cls, pred_cls_masks]
        elif loss_set == {'sal', 'cls'}:
            return_values = [scaled_preds, pred_cls]
        elif loss_set == {'sal', 'contrast'}:
            return_values = [scaled_preds, pred_contrast]
        elif loss_set == {'sal', 'cls_mask'}:
            return_values = [scaled_preds, pred_cls_masks]
        else:
            return_values = [scaled_preds]

        if self.config.lambdas_sal_last['triplet']:
            norm_features = []
            if '_x5' in self.config.triplet:
                norm_features.append(_x5)
            if 'mask' in self.config.triplet:
                norm_features.append(norm_features_mask[0])
            return_values.append(norm_features)
        return return_values