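# Origin: ZhengPeng7, commit 7febe9c ("Initialization."), 11 kB.
# (These lines were copied web-page header text -- avatar caption and the
# "raw" / "history" / "blame" links -- and are kept as a comment so the module parses.)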
from collections import OrderedDict
import torch
from torch.functional import norm
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16, vgg16_bn
import fvcore.nn.weight_init as weight_init
from torchvision.models import resnet50
from models.modules import ResBlk, DSLayer, half_DSLayer, CoAttLayer, RefUnet, DBHead
from config import Config
class GCoNet(nn.Module):
    """Encoder-decoder network: a torchvision backbone plus an FPN-style top-down decoder.

    Every architectural choice is read from ``Config()``: the backbone
    (``bb``), decoder options (``conv_after_itp``,
    ``complex_lateral_connection``, ``output_number``) and the optional heads
    required by the losses listed in ``config.loss`` and by the flags
    ``GAM``, ``split_mask``, ``refine``, ``db_mask``, ``db_output_decoder``
    and ``cls_mask_operation``.
    """

    @staticmethod
    def _build_backbone(bb, bb_pretrained):
        """Split a torchvision backbone into five named encoder stages.

        Args:
            bb: backbone name, one of 'vgg16', 'vgg16bn' or 'resnet50'.
            bb_pretrained: passed as ``pretrained=`` to the torchvision constructor.

        Returns:
            ``(stages, channel_scale)``. ``stages`` is an OrderedDict with keys
            'conv1'..'conv5'. ``channel_scale`` divides the deepest stage's
            channel count when sizing the decoder.

        Raises:
            ValueError: if ``bb`` is not a supported backbone name.
        """
        if bb == 'vgg16':
            bb_net = list(vgg16(pretrained=bb_pretrained).children())[0]
            stages = OrderedDict({
                'conv1': bb_net[:4],
                'conv2': bb_net[4:9],
                'conv3': bb_net[9:16],
                'conv4': bb_net[16:23],
                'conv5': bb_net[23:30],
            })
            return stages, 1
        if bb == 'resnet50':
            bb_net = list(resnet50(pretrained=bb_pretrained).children())
            # Child index 3 is not used by any stage; conv2 starts at index 4.
            stages = OrderedDict({
                'conv1': nn.Sequential(*bb_net[0:3]),
                'conv2': bb_net[4],
                'conv3': bb_net[5],
                'conv4': bb_net[6],
                'conv5': bb_net[7],
            })
            return stages, 4
        if bb == 'vgg16bn':
            bb_net = list(vgg16_bn(pretrained=bb_pretrained).children())[0]
            stages = OrderedDict({
                'conv1': bb_net[:6],
                'conv2': bb_net[6:13],
                'conv3': bb_net[13:23],
                'conv4': bb_net[23:33],
                'conv5': bb_net[33:43],
            })
            return stages, 1
        raise ValueError(
            f"Unsupported backbone {bb!r}; expected 'vgg16', 'vgg16bn' or 'resnet50'."
        )

    def __init__(self, bb_pretrained=True):
        """Build the backbone, the decoder and every optional head enabled in ``Config``.

        Args:
            bb_pretrained: passed as ``pretrained=`` to the torchvision backbone constructor.

        Raises:
            ValueError: if ``Config().bb`` is not a supported backbone name.
        """
        super().__init__()
        self.config = Config()
        bb = self.config.bb
        bb_convs, channel_scale = self._build_backbone(bb, bb_pretrained)
        self.bb = nn.Sequential(bb_convs)

        # Output channels of the encoder stages, deepest first: conv5, conv4, conv3, conv2, conv1.
        lateral_channels_in = [512, 512, 256, 128, 64] if 'vgg16' in bb else [2048, 1024, 512, 256, 64]

        def lateral(channels_in, channels_out):
            # Projects an encoder feature onto the current decoder width.
            if self.config.complex_lateral_connection:
                return ResBlk(channels_in, channels_out)
            return nn.Conv2d(channels_in, channels_out, 1, 1, 0)

        def side_output(channels_in):
            # Single-channel prediction head for an intermediate decoder level.
            return nn.Sequential(
                nn.Conv2d(channels_in, 32, 1, 1, 0),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 1, 1, 1, 0),
            )

        # Decoder width: 256 for every supported backbone (512//2//1 and 2048//2//4),
        # halved after each of the next three levels (ending at 32).
        # Modules are created in the original order so random initialisation is unchanged.
        ch_decoder = lateral_channels_in[0] // 2 // channel_scale

        # Level 5 (deepest).
        self.top_layer = ResBlk(lateral_channels_in[0], ch_decoder)
        self.enlayer5 = ResBlk(ch_decoder, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer5 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer5 = lateral(lateral_channels_in[1], ch_decoder)

        # Level 4: input is the level-5 sum, which still has the previous width.
        ch_decoder //= 2
        self.enlayer4 = ResBlk(ch_decoder * 2, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer4 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer4 = lateral(lateral_channels_in[2], ch_decoder)
        if self.config.output_number >= 4:
            self.conv_out4 = side_output(ch_decoder)

        # Level 3.
        ch_decoder //= 2
        self.enlayer3 = ResBlk(ch_decoder * 2, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer3 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer3 = lateral(lateral_channels_in[3], ch_decoder)
        if self.config.output_number >= 3:
            self.conv_out3 = side_output(ch_decoder)

        # Level 2.
        ch_decoder //= 2
        self.enlayer2 = ResBlk(ch_decoder * 2, ch_decoder)
        if self.config.conv_after_itp:
            self.dslayer2 = DSLayer(ch_decoder, ch_decoder)
        self.latlayer2 = lateral(lateral_channels_in[4], ch_decoder)
        if self.config.output_number >= 2:
            self.conv_out2 = side_output(ch_decoder)

        # Level 1 (full-resolution output).
        self.enlayer1 = ResBlk(ch_decoder, ch_decoder)
        self.conv_out1 = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))

        # Optional heads.
        if self.config.GAM:
            self.co_x5 = CoAttLayer(channel_in=lateral_channels_in[0])
        if 'contrast' in self.config.loss:
            self.pred_layer = half_DSLayer(lateral_channels_in[0])
        if {'cls', 'cls_mask'} & set(self.config.loss):
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.classifier = nn.Linear(lateral_channels_in[0], 291)  # DUTS_class has 291 classes
            weight_init.c2_msra_fill(self.classifier)
        if self.config.split_mask:
            self.sgm = nn.Sigmoid()
        if self.config.refine:
            self.refiner = nn.Sequential(RefUnet(self.config.refine, 64))
        if self.config.split_mask:
            self.conv_out_mask = nn.Sequential(nn.Conv2d(ch_decoder, 1, 1, 1, 0))
        if self.config.db_mask:
            self.db_mask = DBHead(32)
        if self.config.db_output_decoder:
            self.db_output_decoder = DBHead(32)
        if self.config.cls_mask_operation == 'c':
            # Concatenation of a 3-channel input and a 1-channel mask -> 3 channels.
            self.conv_cat_mask = nn.Conv2d(4, 3, 1, 1, 0)

    def forward(self, x):
        """Run the encoder, the enabled auxiliary heads and the top-down decoder.

        Args:
            x: input batch; ``x.size()`` is unpacked as ``(N, C, H, W)``.

        Returns:
            Eval mode: ``scaled_preds``, a list of single-channel prediction
            maps ordered from the deepest decoder level to the full-resolution
            one, with the refiner output appended last when ``config.refine``
            is 1 or 4.
            Training mode: a list starting with ``scaled_preds`` followed by
            the auxiliary outputs selected by the exact set ``config.loss``.
            When ``config.lambdas_sal_last['triplet']`` is truthy, a list of
            triplet features is appended as well.

        Raises:
            ValueError: if the masked-classification branch runs with an
                unknown ``config.cls_mask_operation``.
        """
        ########## Encoder ##########
        N, _, H, W = x.size()
        x1 = self.bb.conv1(x)
        x2 = self.bb.conv2(x1)
        x3 = self.bb.conv3(x2)
        x4 = self.bb.conv4(x3)
        x5 = self.bb.conv5(x4)

        if 'cls' in self.config.loss:
            _x5 = self.avgpool(x5)
            _x5 = _x5.view(_x5.size(0), -1)
            pred_cls = self.classifier(_x5)

        # NOTE: pred_contrast is only produced when GAM is enabled, so a
        # 'contrast' loss in training mode requires config.GAM.
        if self.config.GAM:
            weighted_x5, neg_x5 = self.co_x5(x5)
            if 'contrast' in self.config.loss and self.training:
                ########## contrastive branch #########
                cat_x5 = torch.cat([weighted_x5, neg_x5], dim=0)
                pred_contrast = self.pred_layer(cat_x5)
                pred_contrast = F.interpolate(pred_contrast, size=(H, W), mode='bilinear', align_corners=True)
            p5 = self.top_layer(weighted_x5)
        else:
            p5 = self.top_layer(x5)

        ########## Decoder ##########
        scaled_preds = []

        p5 = self.enlayer5(p5)
        p5 = F.interpolate(p5, size=x4.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p5 = self.dslayer5(p5)
        p4 = p5 + self.latlayer5(x4)

        p4 = self.enlayer4(p4)
        p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p4 = self.dslayer4(p4)
        if self.config.output_number >= 4:
            scaled_preds.append(self.conv_out4(p4))
        p3 = p4 + self.latlayer4(x3)

        p3 = self.enlayer3(p3)
        p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p3 = self.dslayer3(p3)
        if self.config.output_number >= 3:
            scaled_preds.append(self.conv_out3(p3))
        p2 = p3 + self.latlayer3(x2)

        p2 = self.enlayer2(p2)
        p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
        if self.config.conv_after_itp:
            p2 = self.dslayer2(p2)
        if self.config.output_number >= 2:
            scaled_preds.append(self.conv_out2(p2))
        p1 = p2 + self.latlayer2(x1)

        p1 = self.enlayer1(p1)
        p1 = F.interpolate(p1, size=x.shape[2:], mode='bilinear', align_corners=True)
        if self.config.db_output_decoder:
            p1_out = self.db_output_decoder(p1)
        else:
            p1_out = self.conv_out1(p1)
        scaled_preds.append(p1_out)

        if self.config.refine == 1:
            scaled_preds.append(self.refiner(p1_out))
        elif self.config.refine == 4:
            scaled_preds.append(self.refiner(torch.cat([x, p1_out], dim=1)))

        ########## Masked classification ##########
        if 'cls_mask' in self.config.loss:
            pred_cls_masks = []
            norm_features_mask = []
            input_features = [x, x1, x2, x3][:self.config.loss_cls_mask_last_layers]
            bb_lst = [self.bb.conv1, self.bb.conv2, self.bb.conv3, self.bb.conv4, self.bb.conv5]
            # A refiner output, when present, sits after p1_out in scaled_preds.
            refine_offset = int(bool(self.config.refine))
            for idx_out in range(self.config.loss_cls_mask_last_layers):
                if idx_out == 0 and self.config.split_mask:
                    if self.config.db_mask:
                        mask_output = self.db_mask(p1)
                    else:
                        mask_output = self.sgm(self.conv_out_mask(p1))
                else:
                    # Decoder prediction matching input_features[idx_out]. For
                    # idx_out == 0 without split_mask this is p1_out (previously
                    # mask_output was left unbound in that case).
                    mask_output = scaled_preds[-(idx_out + 1 + refine_offset)]

                feature = input_features[idx_out]
                if self.config.cls_mask_operation == 'x':
                    masked_features = feature * mask_output
                elif self.config.cls_mask_operation == '+':
                    masked_features = feature + mask_output
                elif self.config.cls_mask_operation == 'c':
                    masked_features = self.conv_cat_mask(torch.cat((feature, mask_output), dim=1))
                else:
                    raise ValueError(
                        f"Unsupported cls_mask_operation {self.config.cls_mask_operation!r}; "
                        "expected 'x', '+' or 'c'."
                    )

                # Re-encode the masked input with the remaining backbone stages, then pool.
                norm_feature_mask = self.avgpool(
                    nn.Sequential(*bb_lst[idx_out:])(masked_features)
                ).view(N, -1)
                norm_features_mask.append(norm_feature_mask)
                pred_cls_masks.append(self.classifier(norm_feature_mask))

        if not self.training:
            return scaled_preds

        # The auxiliary outputs returned depend on the EXACT set of configured losses.
        loss_set = set(self.config.loss)
        if loss_set == {'sal', 'cls', 'contrast', 'cls_mask'}:
            return_values = [scaled_preds, pred_cls, pred_contrast, pred_cls_masks]
        elif loss_set == {'sal', 'cls', 'contrast'}:
            return_values = [scaled_preds, pred_cls, pred_contrast]
        elif loss_set == {'sal', 'cls', 'cls_mask'}:
            return_values = [scaled_preds, pred_cls, pred_cls_masks]
        elif loss_set == {'sal', 'cls'}:
            return_values = [scaled_preds, pred_cls]
        elif loss_set == {'sal', 'contrast'}:
            return_values = [scaled_preds, pred_contrast]
        elif loss_set == {'sal', 'cls_mask'}:
            return_values = [scaled_preds, pred_cls_masks]
        else:
            return_values = [scaled_preds]

        if self.config.lambdas_sal_last['triplet']:
            # '_x5' requires the 'cls' loss; 'mask' requires the 'cls_mask' loss.
            norm_features = []
            if '_x5' in self.config.triplet:
                norm_features.append(_x5)
            if 'mask' in self.config.triplet:
                norm_features.append(norm_features_mask[0])
            return_values.append(norm_features)
        return return_values