import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import fvcore.nn.weight_init as weight_init
from config import Config


config = Config()


class ResBlk(nn.Module):
    """Two 3x3 convolutions: channel_in -> 64 -> channel_out.

    A ReLU follows the first convolution; BatchNorm is inserted after each
    convolution when ``config.use_bn`` is set. Note that, despite the name,
    ``forward`` adds no identity shortcut.
    """

    def __init__(self, channel_in=64, channel_out=64):
        super(ResBlk, self).__init__()
        self.conv_in = nn.Conv2d(channel_in, 64, 3, 1, 1)
        self.relu_in = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(64, channel_out, 3, 1, 1)
        if config.use_bn:
            self.bn_in = nn.BatchNorm2d(64)
            self.bn_out = nn.BatchNorm2d(channel_out)

    def forward(self, x):
        x = self.conv_in(x)
        if config.use_bn:
            x = self.bn_in(x)
        x = self.relu_in(x)
        x = self.conv_out(x)
        if config.use_bn:
            x = self.bn_out(x)
        return x


class DSLayer(nn.Module):
    """Prediction head: two 3x3 conv+ReLU layers and a 1x1 prediction conv.

    Args:
        channel_in: number of input channels.
        channel_out: number of predicted channels.
        activation_out: when truthy, a ReLU is applied to the prediction.
    """

    def __init__(self, channel_in=64, channel_out=1, activation_out='relu'):
        super(DSLayer, self).__init__()
        self.activation_out = activation_out
        self.conv1 = nn.Conv2d(channel_in, 64, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        # The prediction conv is the same either way; only the trailing ReLU is optional.
        self.pred_conv = nn.Conv2d(64, channel_out, kernel_size=1, stride=1, padding=0)
        if activation_out:
            self.pred_relu = nn.ReLU(inplace=True)
        if config.use_bn:
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(64)
            self.pred_bn = nn.BatchNorm2d(channel_out)

    def forward(self, x):
        x = self.conv1(x)
        if config.use_bn:
            x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        if config.use_bn:
            x = self.bn2(x)
        x = self.relu2(x)
        x = self.pred_conv(x)
        if config.use_bn:
            x = self.pred_bn(x)
        if self.activation_out:
            x = self.pred_relu(x)
        return x


class half_DSLayer(nn.Module):
    """Light prediction head: 3x3 conv (C -> C//4) + ReLU, then 1x1 conv to one channel."""

    def __init__(self, channel_in=512):
        super(half_DSLayer, self).__init__()
        self.enlayer = nn.Sequential(
            nn.Conv2d(channel_in, channel_in // 4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.predlayer = nn.Sequential(
            nn.Conv2d(channel_in // 4, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.enlayer(x)
        x = self.predlayer(x)
        return x


class CoAttLayer(nn.Module):
    """Reweights features by a channel prototype pooled over a group of images.

    The relation module is the class in this module named by
    ``Config().relation_module``; it is constructed with ``channel_in``.

    In training mode, the batch is split into a first and a second half. Each
    half is multiplied by its own prototype (``weighted_x5``) and by the other
    half's prototype (``neg_x5``). In eval mode, the whole batch shares one
    prototype and ``neg_x5`` is None.
    """

    def __init__(self, channel_in=512):
        super(CoAttLayer, self).__init__()
        # Resolve the relation module by name rather than eval()-ing a code string.
        relation_name = Config().relation_module
        relation_cls = globals().get(relation_name)
        if relation_cls is None:
            raise ValueError(
                f'Unknown relation_module {relation_name!r}: '
                'expected the name of a module class defined in this file.'
            )
        self.all_attention = relation_cls(channel_in)
        self.conv_output = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.conv_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_transform = nn.Linear(channel_in, channel_in)

        for layer in [self.conv_output, self.conv_transform, self.fc_transform]:
            weight_init.c2_msra_fill(layer)

    @staticmethod
    def _prototype(feat):
        """Mean over batch and spatial dims, kept as a broadcastable (1, C, 1, 1) tensor."""
        return torch.mean(feat, (0, 2, 3), keepdim=True)

    def forward(self, x5):
        if self.training:
            half = x5.shape[0] // 2
            x5_1 = x5[:half]
            x5_2 = x5[half:]

            x5_1_proto = self._prototype(self.all_attention(x5_1))
            x5_2_proto = self._prototype(self.all_attention(x5_2))

            # Each half weighted by its own prototype ...
            weighted_x5 = torch.cat([x5_1 * x5_1_proto, x5_2 * x5_2_proto], dim=0)
            # ... and by the other half's prototype.
            neg_x5 = torch.cat([x5_1 * x5_2_proto, x5_2 * x5_1_proto], dim=0)
        else:
            x5_proto = self._prototype(self.all_attention(x5))
            weighted_x5 = x5 * x5_proto
            neg_x5 = None
        return weighted_x5, neg_x5


class ICE(nn.Module):
    """The Integrity Channel Enhancement (ICE) module.

    Variables named ``*_1``, ``*_2``, ``*_3`` belong to the first, second and
    third branch respectively.
    """

    def __init__(self, channel_in=512):
        super(ICE, self).__init__()
        self.conv_1 = nn.Conv2d(channel_in, channel_in, 3, 1, 1)
        self.conv_2 = nn.Conv1d(channel_in, channel_in, 3, 1, 1)
        self.conv_3 = nn.Conv2d(channel_in * 3, channel_in, 3, 1, 1)

        self.fc_2 = nn.Linear(channel_in, channel_in)
        self.fc_3 = nn.Linear(channel_in, channel_in)

    def forward(self, x):
        x_1, x_2, x_3 = x, x, x

        x_1 = x_1 * x_2 * x_3                      # x ** 3
        x_2 = x_1 + x_2 + x_3                      # uses the updated x_1: x ** 3 + 2x
        x_3 = torch.cat((x_1, x_2, x_3), dim=1)    # (B, 3C, H, W)

        V = self.conv_1(x_1)

        bs, c, h, w = x_2.shape
        K = self.conv_2(x_2.view(bs, c, h * w))    # (B, C, HW)

        Q_prime = self.conv_3(x_3)
        Q_prime = torch.norm(Q_prime, dim=(-2, -1)).view(bs, -1)  # (B, C)
        Q_prime = self.fc_3(Q_prime)
        Q_prime = torch.softmax(Q_prime, dim=-1)
        Q_prime = Q_prime.unsqueeze(1)             # (B, 1, C)

        Q = torch.matmul(Q_prime, K)               # (B, 1, HW)

        x_2 = F.cosine_similarity(K, Q, dim=-1)    # broadcast over C -> (B, C)
        x_2 = torch.sigmoid(x_2)
        x_2 = self.fc_2(x_2)
        x_2 = x_2.unsqueeze(-1).unsqueeze(-1)      # (B, C, 1, 1)

        x_1 = V * x_2 + V
        return x_1


class GAM(nn.Module):
    """Spatial reweighting from cross-image position affinities.

    Each position is compared (dot product of 1x1-projected features) with every
    position of every image in the batch. Per position, the best match per image
    is averaged over images, scaled by 1/sqrt(C), and softmaxed over the image's
    positions. The resulting (B, 1, H, W) map multiplies the input before a
    final 1x1 conv.
    """

    def __init__(self, channel_in=512):
        super(GAM, self).__init__()
        self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)

        self.scale = 1.0 / (channel_in ** 0.5)

        self.conv6 = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)

        for layer in [self.query_transform, self.key_transform, self.conv6]:
            weight_init.c2_msra_fill(layer)

    def forward(self, x5):
        B, C, H5, W5 = x5.size()

        # Queries: (B, C, HW) -> (B*HW, C)
        x_query = self.query_transform(x5).view(B, C, -1)
        x_query = torch.transpose(x_query, 1, 2).contiguous().view(-1, C)

        # Keys: (B, C, HW) -> (C, B, HW) -> (C, B*HW)
        x_key = self.key_transform(x5).view(B, C, -1)
        x_key = torch.transpose(x_key, 0, 1).contiguous().view(C, -1)

        # Affinities of every position against every position: (B*HW, B*HW)
        x_w = torch.matmul(x_query, x_key)
        x_w = x_w.view(B * H5 * W5, B, H5 * W5)
        x_w = torch.max(x_w, -1).values            # best match per image: (B*HW, B)
        x_w = x_w.mean(-1)                         # averaged over images: (B*HW,)
        # Scaling is applied after the max/mean reduction.
        x_w = x_w.view(B, -1) * self.scale         # (B, HW)
        x_w = F.softmax(x_w, dim=-1)               # (B, HW)
        x_w = x_w.view(B, H5, W5).unsqueeze(1)     # (B, 1, H, W)

        x5 = x5 * x_w
        x5 = self.conv6(x5)
        return x5


class MHA(nn.Module):
    '''
    Multi-head scaled dot-product self-attention over the spatial positions of a feature map.
    '''

    def __init__(self, d_model=512, d_k=512, d_v=512, h=8, dropout=.1, channel_in=512):
        '''
        :param d_model: Output dimensionality of the model (must equal the input channel count)
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(MHA, self).__init__()
        self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.value_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, attention_mask=None, attention_weights=None):
        '''
        Computes self-attention among the H*W positions of ``x``.
        :param x: Feature map (B, C, H, W); C must equal d_model
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Feature map (B, C, H, W)
        '''
        B, C, H, W = x.size()
        # (B, C, H, W) -> (B, HW, C): one token per spatial position.
        # NOTE: earlier code used query_transform for keys/values too and reshaped
        # with .view(B, -1, C), which mixes channels with positions. Checkpoints
        # trained with that code will produce different outputs here.
        queries = self.query_transform(x).flatten(2).transpose(1, 2)
        keys = self.key_transform(x).flatten(2).transpose(1, 2)
        values = self.value_transform(x).flatten(2).transpose(1, 2)

        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        # (B, HW, C) -> (B, C, H, W)
        out = out.transpose(1, 2).reshape(B, C, H, W)
        return out


class NonLocal(nn.Module):
    """Embedded-Gaussian non-local block on 2-D feature maps, with a residual connection.

    All layers are 2-D (Conv2d / MaxPool2d), so the input must be (B, C, H, W).
    ``dimension`` is only validated. When ``sub_sample`` is set, the ``g`` and
    ``phi`` branches are max-pooled by 2. When ``bn_layer`` is set, the output
    projection ends in a zero-initialised BatchNorm; otherwise the projection
    conv itself is zero-initialised. Either way, the block starts as identity.
    """

    def __init__(self, channel_in=512, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True):
        super(NonLocal, self).__init__()
        if dimension not in [1, 2, 3]:
            raise ValueError(f'dimension must be 1, 2 or 3, got {dimension!r}')

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.channel_in = channel_in
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = channel_in // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        self.g = nn.Conv2d(self.channel_in, self.inter_channels, 1, 1, 0)

        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv2d(self.inter_channels, self.channel_in, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(self.channel_in),
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv2d(self.inter_channels, self.channel_in, kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = nn.Conv2d(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
            self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))

    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, h, w)
        :param return_nl_map: if True return (z, attention map), else only z.
        :return: z = x + W(attention-aggregated g(x)), same shape as x.
        """
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z


class DBHead(nn.Module):
    """Predicts a shrink map and a threshold map, then combines them with a soft step.

    Both branches are conv stacks ending in a Sigmoid. The output is
    ``1 / (1 + exp(-k * (shrink - threshold)))``; see ``step_function``.
    """

    def __init__(self, channel_in=32, channel_out=1, k=config.db_k):
        super().__init__()
        self.k = k

        def norm_act():
            # BatchNorm+ReLU when BN is enabled, ReLU alone otherwise. Always a
            # list, so it can be star-unpacked: unpacking a bare nn.ReLU raised
            # TypeError when use_bn was False.
            if config.use_bn:
                return [nn.BatchNorm2d(channel_in), nn.ReLU(inplace=True)]
            return [nn.ReLU(inplace=True)]

        self.binarize = nn.Sequential(
            nn.Conv2d(channel_in, channel_in, 3, 1, 1),
            *norm_act(),
            nn.Conv2d(channel_in, channel_in, 3, 1, 1),
            *norm_act(),
            nn.Conv2d(channel_in, channel_out, 1, 1, 0),
            nn.Sigmoid(),
        )

        self.thresh = nn.Sequential(
            nn.Conv2d(channel_in, channel_in, 3, padding=1),
            *norm_act(),
            nn.Conv2d(channel_in, channel_in, 3, 1, 1),
            *norm_act(),
            nn.Conv2d(channel_in, channel_out, 1, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        return binary_maps

    def step_function(self, x, y):
        """Return ``1 / (1 + exp(-k * d))`` where ``d`` depends on ``z = x - y``.

        When ``config.db_k_alpha != 1``, ``d`` is the signed power
        ``sign(z) * |z| ** (1 / db_k_alpha)``; otherwise ``d = z``. In the
        latter case, if ``exp`` overflows to inf, it is recomputed with k = 50.
        """
        if config.db_k_alpha != 1:
            z = x - y
            # +1 where z >= 0, -1 where z < 0: used to take a signed power.
            mask_neg_inv = 1 - 2 * (z < 0)
            a = torch.exp(-self.k * (torch.pow(z * mask_neg_inv + 1e-16, 1 / config.db_k_alpha) * mask_neg_inv))
        else:
            a = torch.exp(-self.k * (x - y))
            if torch.isinf(a).any():
                a = torch.exp(-50 * (x - y))
        return torch.reciprocal(1 + a)


class RefUnet(nn.Module):
    """Refinement U-Net: a 4-level encoder/decoder with skip connections.

    By default it returns ``x + residual``, where the residual is a one-channel
    prediction. When ``config.db_output_refiner`` is set, it returns the
    ``DBHead`` output of the last decoder features instead.
    """

    def __init__(self, in_ch, inc_ch):
        super(RefUnet, self).__init__()
        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)

        self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1)
        if config.use_bn:
            self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)

        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)

        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)

        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)

        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)

        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU(inplace=True)

        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # Bottleneck
        self.conv5 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU(inplace=True)

        # Decoder
        self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d4 = nn.BatchNorm2d(64)
        self.relu_d4 = nn.ReLU(inplace=True)

        self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d3 = nn.BatchNorm2d(64)
        self.relu_d3 = nn.ReLU(inplace=True)

        self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d2 = nn.BatchNorm2d(64)
        self.relu_d2 = nn.ReLU(inplace=True)

        self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d1 = nn.BatchNorm2d(64)
        self.relu_d1 = nn.ReLU(inplace=True)

        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)

        # Kept for backward compatibility; forward() uses _upsample_like, which
        # also handles the odd sizes produced by ceil_mode pooling.
        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        if config.db_output_refiner:
            self.db_output_refiner = DBHead(64)

    @staticmethod
    def _upsample_like(src, ref):
        """Bilinearly resize ``src`` to the spatial size of ``ref``.

        With align_corners=True this matches a x2 upsample exactly when ref is
        twice the size of src. It also handles odd sizes, where x2 would
        overshoot the skip connection by one pixel.
        """
        return F.interpolate(src, size=ref.shape[2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        hx = x

        # Encoder
        hx = self.conv1(self.conv0(hx))
        if config.use_bn:
            hx = self.bn1(hx)
        hx1 = self.relu1(hx)

        hx = self.conv2(self.pool1(hx1))
        if config.use_bn:
            hx = self.bn2(hx)
        hx2 = self.relu2(hx)

        hx = self.conv3(self.pool2(hx2))
        if config.use_bn:
            hx = self.bn3(hx)
        hx3 = self.relu3(hx)

        hx = self.conv4(self.pool3(hx3))
        if config.use_bn:
            hx = self.bn4(hx)
        hx4 = self.relu4(hx)

        hx = self.conv5(self.pool4(hx4))
        if config.use_bn:
            hx = self.bn5(hx)
        hx5 = self.relu5(hx)

        # Decoder with skip connections
        hx = self._upsample_like(hx5, hx4)
        d4 = self.conv_d4(torch.cat((hx, hx4), 1))
        if config.use_bn:
            d4 = self.bn_d4(d4)
        d4 = self.relu_d4(d4)

        hx = self._upsample_like(d4, hx3)
        d3 = self.conv_d3(torch.cat((hx, hx3), 1))
        if config.use_bn:
            d3 = self.bn_d3(d3)
        d3 = self.relu_d3(d3)

        hx = self._upsample_like(d3, hx2)
        d2 = self.conv_d2(torch.cat((hx, hx2), 1))
        if config.use_bn:
            d2 = self.bn_d2(d2)
        d2 = self.relu_d2(d2)

        hx = self._upsample_like(d2, hx1)
        d1 = self.conv_d1(torch.cat((hx, hx1), 1))
        if config.use_bn:
            d1 = self.bn_d1(d1)
        d1 = self.relu_d1(d1)

        if config.db_output_refiner:
            x = self.db_output_refiner(d1)
        else:
            residual = self.conv_d0(d1)
            x = x + residual
        return x