ZhengPeng7's picture
Initialization.
7febe9c
raw
history blame
18.3 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import fvcore.nn.weight_init as weight_init
from config import Config
config = Config()
class ResBlk(nn.Module):
    """Two stacked 3x3 convolutions with an optional BatchNorm after each.

    The pipeline is conv -> (BN) -> ReLU -> conv -> (BN). The hidden width is
    fixed at 64 channels. Note that, despite the class name, ``forward`` does
    not add the input back onto the output (there is no skip connection here).

    BatchNorm layers are created and applied only when ``config.use_bn`` is
    truthy.
    """

    def __init__(self, channel_in=64, channel_out=64):
        """
        :param channel_in: Number of channels of the input feature map.
        :param channel_out: Number of channels of the returned feature map.
        """
        super().__init__()
        hidden = 64
        self.conv_in = nn.Conv2d(channel_in, hidden, kernel_size=3, stride=1, padding=1)
        self.relu_in = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(hidden, channel_out, kernel_size=3, stride=1, padding=1)
        if config.use_bn:
            self.bn_in = nn.BatchNorm2d(hidden)
            self.bn_out = nn.BatchNorm2d(channel_out)

    def forward(self, x):
        """Return the transformed feature map; no activation follows the last conv/BN."""
        feat = self.conv_in(x)
        if config.use_bn:
            feat = self.bn_in(feat)
        feat = self.conv_out(self.relu_in(feat))
        if config.use_bn:
            feat = self.bn_out(feat)
        return feat
class DSLayer(nn.Module):
    """Small prediction head: two 3x3 conv stages followed by a 1x1 prediction conv.

    Each of the first two stages is conv -> (BN) -> ReLU with 64 channels.
    The prediction stage is a 1x1 conv -> (BN) -> optional ReLU.
    BatchNorm layers exist only when ``config.use_bn`` is truthy.
    """

    def __init__(self, channel_in=64, channel_out=1, activation_out='relu'):
        """
        :param channel_in: Number of input channels.
        :param channel_out: Number of channels produced by the prediction conv.
        :param activation_out: Any truthy value appends a ReLU after the
            prediction stage; a falsy value leaves the prediction un-activated.
            Only its truthiness is inspected.
        """
        super().__init__()
        self.activation_out = activation_out
        width = 64

        self.conv1 = nn.Conv2d(channel_in, width, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)

        # The 1x1 prediction conv is needed either way; only the trailing ReLU is optional.
        self.pred_conv = nn.Conv2d(width, channel_out, kernel_size=1, stride=1, padding=0)
        if activation_out:
            self.pred_relu = nn.ReLU(inplace=True)

        if config.use_bn:
            self.bn1 = nn.BatchNorm2d(width)
            self.bn2 = nn.BatchNorm2d(width)
            self.pred_bn = nn.BatchNorm2d(channel_out)

    def forward(self, x):
        """Run the three stages in order and return the prediction map."""
        out = self.conv1(x)
        if config.use_bn:
            out = self.bn1(out)
        out = self.conv2(self.relu1(out))
        if config.use_bn:
            out = self.bn2(out)
        out = self.pred_conv(self.relu2(out))
        if config.use_bn:
            # BatchNorm is applied to the prediction before the optional ReLU.
            out = self.pred_bn(out)
        if self.activation_out:
            out = self.pred_relu(out)
        return out
class half_DSLayer(nn.Module):
    """Lightweight head: a 3x3 conv reducing channels to a quarter, ReLU, then a 1x1 conv to one channel."""

    def __init__(self, channel_in=512):
        """
        :param channel_in: Number of input channels; the hidden width is ``channel_in // 4``.
        """
        super().__init__()
        reduced = int(channel_in // 4)
        self.enlayer = nn.Sequential(
            nn.Conv2d(channel_in, reduced, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.predlayer = nn.Sequential(
            nn.Conv2d(reduced, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        """Return the single-channel map; no activation is applied to the output."""
        return self.predlayer(self.enlayer(x))
class CoAttLayer(nn.Module):
    """Re-weights features channel-wise with a prototype derived from an attention module.

    The attention module is selected by name through ``Config().relation_module``.
    A prototype is the mean of the attended features over the batch and both
    spatial axes, giving one scalar per channel, which then multiplies the
    un-attended input features.

    In training mode the batch is split into two halves: each half is
    weighted by its own prototype (first return value) and by the other
    half's prototype (second return value). In eval mode the whole batch
    shares one prototype and the second return value is ``None``.
    """

    def __init__(self, channel_in=512):
        """
        :param channel_in: Number of channels of the incoming feature map.
        """
        super().__init__()
        # NOTE(review): eval() builds the attention module from a config string
        # (e.g. the name of a class defined in this file). This is only safe
        # while the configuration is trusted; never feed it external input.
        self.all_attention = eval(Config().relation_module + '(channel_in)')
        # NOTE(review): the three layers below are created and initialised but
        # are not used by forward() in this class.
        self.conv_output = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.conv_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_transform = nn.Linear(channel_in, channel_in)

        for layer in (self.conv_output, self.conv_transform, self.fc_transform):
            weight_init.c2_msra_fill(layer)

    def _prototype(self, feats):
        """Attend to ``feats`` and average over batch and spatial axes -> (1, C, 1, 1)."""
        attended = self.all_attention(feats)
        return torch.mean(attended, (0, 2, 3), True).view(1, -1, 1, 1)

    def forward(self, x5):
        """
        :param x5: Feature map indexed as (batch, channel, height, width).
        :return: ``(weighted_x5, neg_x5)``; ``neg_x5`` is ``None`` outside training.
        """
        if not self.training:
            return x5 * self._prototype(x5), None

        # First half / second half of the batch (the second half gets the
        # extra sample when the batch size is odd).
        half = int(x5.shape[0] / 2)
        first, second = x5[:half], x5[half:int(x5.shape[0])]

        proto_first = self._prototype(first)
        proto_second = self._prototype(second)

        # Each half weighted by its own prototype.
        weighted_x5 = torch.cat([first * proto_first, second * proto_second], dim=0)
        # Each half weighted by the other half's prototype.
        neg_x5 = torch.cat([first * proto_second, second * proto_first], dim=0)
        return weighted_x5, neg_x5
class ICE(nn.Module):
    """The Integrity Channel Enhancement (ICE) module.

    Three views of the input are formed: the element-wise cube ``x*x*x``
    (column 1), ``x*x*x + x + x`` (column 2) and the channel-wise
    concatenation of those two with ``x`` (column 3). Column 2 and column 3
    are combined into a per-channel gate that rescales the convolved
    column 1 features, with a residual connection on those features.
    """

    def __init__(self, channel_in=512):
        """
        :param channel_in: Number of channels of the input (and output) feature map.
        """
        super().__init__()
        self.conv_1 = nn.Conv2d(channel_in, channel_in, 3, 1, 1)
        # Column 2 is flattened over space, hence a 1-D convolution.
        self.conv_2 = nn.Conv1d(channel_in, channel_in, 3, 1, 1)
        self.conv_3 = nn.Conv2d(channel_in * 3, channel_in, 3, 1, 1)

        self.fc_2 = nn.Linear(channel_in, channel_in)
        self.fc_3 = nn.Linear(channel_in, channel_in)

    def forward(self, x):
        """
        :param x: Feature map of shape (bs, c, h, w).
        :return: Tensor with the same shape as ``x``.
        """
        cubed = x * x * x
        summed = cubed + x + x
        stacked = torch.cat((cubed, summed, x), dim=1)  # (bs, 3c, h, w)

        V = self.conv_1(cubed)

        bs, c, h, w = summed.shape
        K = self.conv_2(summed.view(bs, c, h * w))  # (bs, c, hw)

        # Per-channel spatial norm of column 3 -> softmax weights over channels.
        channel_energy = torch.norm(self.conv_3(stacked), dim=(-2, -1)).view(bs, c)
        Q_prime = torch.softmax(self.fc_3(channel_energy), dim=-1).unsqueeze(1)  # (bs, 1, c)

        # Channel-weighted mixture of the rows of K.
        Q = torch.matmul(Q_prime, K)  # (bs, 1, hw)

        # Similarity of every channel of K to the mixture, turned into a gate.
        gate = torch.nn.functional.cosine_similarity(K, Q, dim=-1)  # (bs, c)
        gate = self.fc_2(torch.sigmoid(gate))
        gate = gate.unsqueeze(-1).unsqueeze(-1)  # (bs, c, 1, 1)

        return V * gate + V
class GAM(nn.Module):
    """Spatial attention computed from query/key affinities across the whole batch.

    Every spatial position of every image is compared (dot product of 1x1-conv
    query and key features) with every position of every image in the batch.
    For each query position the best match inside each image is taken and
    averaged over images; a softmax over the positions of each image turns
    these scores into a spatial map that re-weights the input before a final
    1x1 convolution.
    """

    def __init__(self, channel_in=512):
        """
        :param channel_in: Number of channels of the input (and output) feature map.
        """
        super().__init__()
        self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)

        self.scale = 1.0 / (channel_in ** 0.5)

        self.conv6 = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)

        for layer in (self.query_transform, self.key_transform, self.conv6):
            weight_init.c2_msra_fill(layer)

    def forward(self, x5):
        """
        :param x5: Feature map of shape (B, C, H, W).
        :return: Re-weighted and convolved feature map of shape (B, C, H, W).
        """
        batch, channels, height, width = x5.size()
        positions = height * width

        # Queries as rows: (B, C, HW) -> (B, HW, C) -> (B*HW, C)
        query = self.query_transform(x5).view(batch, channels, -1)
        query = query.permute(0, 2, 1).contiguous().view(-1, channels)

        # Keys as columns: (B, C, HW) -> (C, B, HW) -> (C, B*HW)
        key = self.key_transform(x5).view(batch, channels, -1)
        key = key.permute(1, 0, 2).contiguous().view(channels, -1)

        # Affinity of every query position with every key position: (B*HW, B*HW)
        affinity = torch.matmul(query, key)
        # Group the key positions by image: (B*HW, B, HW)
        affinity = affinity.view(batch * positions, batch, positions)
        # Strongest response inside each image, then average over images: (B*HW,)
        score = affinity.max(dim=-1).values.mean(dim=-1)

        # The scale is applied after the max/mean, just before the softmax.
        score = score.view(batch, -1) * self.scale  # (B, HW)
        attention = F.softmax(score, dim=-1)
        attention = attention.view(batch, height, width).unsqueeze(1)  # (B, 1, H, W)

        return self.conv6(x5 * attention)
class MHA(nn.Module):
    """
    Multi-head scaled dot-product self-attention over the spatial positions
    of a feature map.

    Every spatial position of the (B, C, H, W) input is one token with C
    features. Queries, keys and values are produced by three separate 1x1
    convolutions followed by per-head linear projections.
    """

    def __init__(self, d_model=512, d_k=512, d_v=512, h=8, dropout=.1, channel_in=512):
        """
        :param d_model: Output dimensionality of the model; must equal the
            channel count of the input passed to ``forward``.
        :param d_k: Dimensionality of queries and keys (per head)
        :param d_v: Dimensionality of values (per head)
        :param h: Number of heads
        :param dropout: Dropout probability applied to the attention weights
        :param channel_in: Number of channels of the input feature map
        """
        super(MHA, self).__init__()
        self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.value_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        """Kaiming-normal init for convs, N(0, 0.001) for linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, attention_mask=None, attention_weights=None):
        """
        Computes self-attention between all spatial positions of ``x``.

        :param x: Feature map (B, C, H, W) with C == d_model
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Feature map (B, C, H, W)
        """
        B, C, H, W = x.size()
        # One token per spatial position: (B, C, H, W) -> (B, HW, C).
        # A plain .view(B, -1, C) would only reinterpret the channel-major
        # memory and mix channels with positions, so flatten + transpose.
        # NOTE: keys/values now use their own transforms (they previously
        # reused query_transform); checkpoints trained with the old forward
        # will not reproduce their results with this corrected one.
        queries = self.query_transform(x).flatten(2).transpose(1, 2)
        keys = self.key_transform(x).flatten(2).transpose(1, 2)
        values = self.value_transform(x).flatten(2).transpose(1, 2)

        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        # Back to channel-first layout: (B, HW, C) -> (B, C, H, W).
        out = out.transpose(1, 2).contiguous().view(B, C, H, W)
        return out
class NonLocal(nn.Module):
    """Non-local (self-attention) block built from 2-D convolutions with a residual output.

    ``theta`` / ``phi`` / ``g`` are 1x1 convolutions to ``inter_channels``;
    the softmax of ``theta^T phi`` weights ``g`` and the result is projected
    back to ``channel_in`` by ``W`` and added to the input. ``W`` (or its
    BatchNorm) is zero-initialised, so a freshly constructed block returns
    its input unchanged.
    """

    def __init__(self, channel_in=512, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True):
        """
        :param channel_in: Number of channels of the input feature map.
        :param inter_channels: Width of the embedding; defaults to
            ``channel_in // 2`` (at least 1).
        :param dimension: Must be 1, 2 or 3. It is stored but all layers
            created here are 2-D regardless of its value.
        :param sub_sample: If True, ``g`` and ``phi`` are followed by 2x2 max-pooling.
        :param bn_layer: If True, ``W`` is conv + BatchNorm, otherwise a single conv.
        """
        super().__init__()
        assert dimension in (1, 2, 3)
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.channel_in = channel_in

        if inter_channels is None:
            # Halve the width but never drop to zero channels.
            inter_channels = channel_in // 2 or 1
        self.inter_channels = inter_channels

        g_conv = nn.Conv2d(self.channel_in, self.inter_channels, 1, 1, 0)
        self.g = nn.Sequential(g_conv, nn.MaxPool2d(kernel_size=(2, 2))) if sub_sample else g_conv

        w_conv = nn.Conv2d(self.inter_channels, self.channel_in, kernel_size=1, stride=1, padding=0)
        if bn_layer:
            self.W = nn.Sequential(w_conv, nn.BatchNorm2d(self.channel_in))
            zeroed = self.W[1]
        else:
            self.W = w_conv
            zeroed = self.W
        # Zero affine parameters make the residual branch start as a no-op.
        nn.init.constant_(zeroed.weight, 0)
        nn.init.constant_(zeroed.bias, 0)

        self.theta = nn.Conv2d(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)

        phi_conv = nn.Conv2d(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Sequential(phi_conv, nn.MaxPool2d(kernel_size=(2, 2))) if sub_sample else phi_conv

    def forward(self, x, return_nl_map=False):
        """
        :param x: (b, c, h, w)
        :param return_nl_map: if True return z, nl_map, else only return z.
        :return: ``z`` with the shape of ``x``; ``nl_map`` is the softmax
            attention of shape (b, number of query positions, number of key positions).
        """
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # Pairwise affinities, normalised over the key positions.
        nl_map = F.softmax(torch.matmul(theta_x, phi_x), dim=-1)

        y = torch.matmul(nl_map, g_x).permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        z = self.W(y) + x
        return (z, nl_map) if return_nl_map else z
class DBHead(nn.Module):
    """Head that predicts a probability map and a threshold map and fuses them
    with a steep sigmoid, ``1 / (1 + exp(-k * (P - T)))``.

    Two identically shaped branches (``binarize`` and ``thresh``) each run
    conv3x3 -> (BN) -> ReLU -> conv3x3 -> (BN) -> ReLU -> conv1x1 -> Sigmoid.
    BatchNorm layers are included only when ``config.use_bn`` is truthy.
    """

    def __init__(self, channel_in=32, channel_out=1, k=config.db_k):
        """
        :param channel_in: Number of input channels (also the hidden width).
        :param channel_out: Number of channels of each predicted map.
        :param k: Steepness of the sigmoid used in ``step_function``.
        """
        super().__init__()
        self.k = k
        self.binarize = self._make_branch(channel_in, channel_out)
        self.thresh = self._make_branch(channel_in, channel_out)

    @staticmethod
    def _make_branch(channel_in, channel_out):
        """Build one prediction branch; layer order/indices match for both branches."""
        layers = []
        for _ in range(2):
            layers.append(nn.Conv2d(channel_in, channel_in, 3, 1, 1))
            # The previous `*[bn, relu] if use_bn else relu` unpacked the whole
            # conditional, so use_bn=False tried to iterate a ReLU module and
            # raised TypeError. Build the list explicitly instead.
            if config.use_bn:
                layers.append(nn.BatchNorm2d(channel_in))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(channel_in, channel_out, 1, 1, 0))
        layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the fused (approximately binary) map for feature map ``x``."""
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        return binary_maps

    def step_function(self, x, y):
        """Steep sigmoid of ``x - y``.

        When ``config.db_k_alpha`` is not 1, the magnitude of the difference
        is first raised to the power ``1 / alpha`` (the sign is preserved).
        If the exponential overflows to infinity, a milder steepness of 50
        is used instead of ``self.k``.
        """
        diff = x - y
        if config.db_k_alpha != 1:
            # The branch is gated on `db_k_alpha` but the exponent used to be
            # read from `k_alpha`; honour `k_alpha` if the config defines it,
            # otherwise fall back to the attribute that was just tested.
            alpha = getattr(config, 'k_alpha', config.db_k_alpha)
            # +1 where diff >= 0, -1 where diff < 0.
            sign = 1 - 2 * (diff < 0)
            a = torch.exp(-self.k * (torch.pow(diff * sign + 1e-16, 1 / alpha) * sign))
        else:
            a = torch.exp(-self.k * diff)
        if torch.isinf(a).any():
            a = torch.exp(-50 * diff)
        return torch.reciprocal(1 + a)
class RefUnet(nn.Module):
    """Refinement U-Net.

    Encoder: ``conv0`` followed by four conv -> (BN) -> ReLU stages, each
    followed by 2x2 max-pooling (``ceil_mode=True``), then a bottleneck stage.
    Decoder: four stages, each bilinearly upsampling the previous output,
    concatenating the matching encoder feature and applying
    conv -> (BN) -> ReLU.

    Output: when ``config.db_output_refiner`` is truthy the decoder output
    goes through a ``DBHead``; otherwise a one-channel residual is predicted
    and added to the input. BatchNorm layers exist only when
    ``config.use_bn`` is truthy.
    """

    def __init__(self, in_ch, inc_ch):
        """
        :param in_ch: Number of channels of the input to refine.
        :param inc_ch: Number of channels produced by the first convolution.
        """
        super(RefUnet, self).__init__()
        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)

        # ----- encoder -----
        self.conv1 = nn.Conv2d(inc_ch, 64, 3, padding=1)
        if config.use_bn:
            self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)

        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)

        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)

        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)

        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)

        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU(inplace=True)

        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # ----- bottleneck -----
        self.conv5 = nn.Conv2d(64, 64, 3, padding=1)
        if config.use_bn:
            self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU(inplace=True)

        # ----- decoder (inputs are 64 upsampled + 64 skip channels) -----
        self.conv_d4 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d4 = nn.BatchNorm2d(64)
        self.relu_d4 = nn.ReLU(inplace=True)

        self.conv_d3 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d3 = nn.BatchNorm2d(64)
        self.relu_d3 = nn.ReLU(inplace=True)

        self.conv_d2 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d2 = nn.BatchNorm2d(64)
        self.relu_d2 = nn.ReLU(inplace=True)

        self.conv_d1 = nn.Conv2d(128, 64, 3, padding=1)
        if config.use_bn:
            self.bn_d1 = nn.BatchNorm2d(64)
        self.relu_d1 = nn.ReLU(inplace=True)

        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)

        # Kept for backward compatibility with code that accesses the
        # attribute; forward() upsamples to the skip tensor's size instead.
        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        if config.db_output_refiner:
            self.db_output_refiner = DBHead(64)

    def _stage(self, suffix, feat):
        """Apply ``conv<suffix>`` -> (``bn<suffix>``) -> ``relu<suffix>`` to ``feat``."""
        feat = getattr(self, 'conv' + suffix)(feat)
        if config.use_bn:
            feat = getattr(self, 'bn' + suffix)(feat)
        return getattr(self, 'relu' + suffix)(feat)

    @staticmethod
    def _upsample_like(feat, ref):
        """Bilinearly resize ``feat`` to the spatial size of ``ref``."""
        return F.interpolate(feat, size=ref.shape[2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        """
        :param x: Map to refine, shape (B, in_ch, H, W).
        :return: The ``DBHead`` output when ``config.db_output_refiner`` is
            truthy, otherwise ``x`` plus the predicted one-channel residual.
        """
        hx1 = self._stage('1', self.conv0(x))
        hx2 = self._stage('2', self.pool1(hx1))
        hx3 = self._stage('3', self.pool2(hx2))
        hx4 = self._stage('4', self.pool3(hx3))
        hx5 = self._stage('5', self.pool4(hx4))

        # ceil_mode pooling makes a pooled size ceil(n / 2), so a fixed x2
        # upsample overshoots the skip feature whenever n is odd and the
        # concatenation fails. Resizing to the skip's own size is identical
        # for even sizes (align_corners=True) and also handles odd ones.
        d4 = self._stage('_d4', torch.cat((self._upsample_like(hx5, hx4), hx4), 1))
        d3 = self._stage('_d3', torch.cat((self._upsample_like(d4, hx3), hx3), 1))
        d2 = self._stage('_d2', torch.cat((self._upsample_like(d3, hx2), hx2), 1))
        d1 = self._stage('_d1', torch.cat((self._upsample_like(d2, hx1), hx1), 1))

        if config.db_output_refiner:
            return self.db_output_refiner(d1)
        residual = self.conv_d0(d1)
        return x + residual