import torch
import torch.nn as nn
import torch.nn.functional as F


def freeze_weights(module):
    """Disable gradient computation for every parameter of ``module``.

    Args:
        module: an ``nn.Module``; each of its parameters gets
            ``requires_grad = False``.
    """
    for param in module.parameters():
        param.requires_grad = False


def l1_regularize(module):
    """Return the summed absolute value of the trainable weights in ``module.reg_params``.

    Only entries whose key contains ``"weight"`` and whose tensor has
    ``requires_grad`` set contribute; other entries (e.g. biases, frozen
    tensors) are skipped.

    NOTE(review): ``reg_params`` is not defined in this file; it is assumed to
    be a name -> tensor mapping attached to the module elsewhere — confirm.

    Args:
        module: object exposing a ``reg_params`` mapping with ``.items()``.

    Returns:
        A scalar tensor, or the float ``0.`` when no entry qualifies.
    """
    reg_loss = 0.
    for key, param in module.reg_params.items():
        if "weight" in key and param.requires_grad:
            reg_loss += torch.sum(torch.abs(param))
    return reg_loss


class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ``conv1`` uses ``groups=in_channels`` so each input channel is filtered
    independently; ``pointwise`` then mixes channels and maps
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                               padding, dilation, groups=in_channels,
                               bias=bias)
        # Positional args: kernel 1, stride 1, padding 0, dilation 1, groups 1.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1,
                                   bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    """Residual block built from ``reps`` (ReLU, 3x3 SeparableConv2d[, BN]) units.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        reps: number of ReLU -> SeparableConv2d(3x3) -> (BatchNorm) units.
        strides: when != 1, a 3x3 max-pool with this stride is appended to the
            main path and the skip path becomes a strided 1x1 convolution.
        start_with_relu: when False, the leading ReLU of the main path is
            dropped.
        grow_first: when True the first unit switches the channel count to
            ``out_channels``; otherwise the last unit does.
        with_bn: whether BatchNorm follows each conv (and the skip conv).
    """

    def __init__(self, in_channels, out_channels, reps, strides=1,
                 start_with_relu=True, grow_first=True, with_bn=True):
        super(Block, self).__init__()
        self.with_bn = with_bn
        # A projection skip is needed whenever the shape changes.
        if out_channels != in_channels or strides != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1,
                                  stride=strides, bias=False)
            if with_bn:
                self.skipbn = nn.BatchNorm2d(out_channels)
        else:
            self.skip = None

        rep = []
        for i in range(reps):
            if grow_first:
                inc = in_channels if i == 0 else out_channels
                outc = out_channels
            else:
                inc = in_channels
                outc = in_channels if i < (reps - 1) else out_channels
            rep.append(nn.ReLU(inplace=True))
            rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1))
            if with_bn:
                rep.append(nn.BatchNorm2d(outc))

        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first ReLU sees the block input directly; it must not be
            # in-place because forward() reuses that input for the skip path.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            if self.with_bn:
                skip = self.skipbn(skip)
        else:
            skip = inp
        x += skip
        return x


class GraphReasoning(nn.Module):
    """Graph Reasoning Module for information aggregation.

    Every spatial position of ``vert_a`` is treated as a node connected to a
    non-overlapping ``ratio[0] x ratio[0]`` patch of ``vert_b`` and a
    ``ratio[1] x ratio[1]`` patch of ``vert_c`` (``nn.Unfold`` with
    ``kernel_size == stride``). Neighbour embeddings are pooled with learned
    softmax weights, gated by ``1 - sigmoid(conv(vert_a))``, concatenated with
    the embedding of ``vert_a``, passed through a sigmoid and re-projected to
    ``va_in`` channels.

    Args:
        va_in, va_out: input / embedding channels for ``vert_a``.
        vb_in, vb_out: input / embedding channels for ``vert_b``.
        vc_in, vc_out: input / embedding channels for ``vert_c``.
        spatial_ratio: pair ``(r_b, r_c)`` used as kernel size and stride of
            the unfolds over ``vert_b`` and ``vert_c``.
        drop_rate: dropout probability applied after re-projection, or
            ``None`` for no dropout.

    NOTE(review): pooled ``vb_out``/``vc_out`` features are multiplied
    element-wise by ``va_out``-channel gates, so these sizes must match (or
    broadcast).
    """

    def __init__(self, va_in, va_out, vb_in, vb_out, vc_in, vc_out,
                 spatial_ratio, drop_rate):
        super(GraphReasoning, self).__init__()
        self.ratio = spatial_ratio
        self.va_embedding = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(va_out, va_out, 1, bias=False),
        )
        self.va_gated_b = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.Sigmoid()
        )
        self.va_gated_c = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.Sigmoid()
        )
        self.vb_embedding = nn.Sequential(
            nn.Linear(vb_in, vb_out, bias=False),
            nn.ReLU(True),
            nn.Linear(vb_out, vb_out, bias=False),
        )
        self.vc_embedding = nn.Sequential(
            nn.Linear(vc_in, vc_out, bias=False),
            nn.ReLU(True),
            nn.Linear(vc_out, vc_out, bias=False),
        )
        self.unfold_b = nn.Unfold(kernel_size=spatial_ratio[0],
                                  stride=spatial_ratio[0])
        self.unfold_c = nn.Unfold(kernel_size=spatial_ratio[1],
                                  stride=spatial_ratio[1])
        # Scores each (a-node, neighbour) pair; Softmax(dim=1) normalises the
        # scores over the neighbours of one node.
        self.reweight_ab = nn.Sequential(
            nn.Linear(va_out + vb_out, 1, bias=False),
            nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        self.reweight_ac = nn.Sequential(
            nn.Linear(va_out + vc_out, 1, bias=False),
            nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        self.reproject = nn.Sequential(
            nn.Conv2d(va_out + vb_out + vc_out, va_in, kernel_size=1,
                      bias=False),
            nn.ReLU(True),
            nn.Conv2d(va_in, va_in, kernel_size=1, bias=False),
            nn.Dropout(drop_rate) if drop_rate is not None else nn.Identity(),
        )

    def forward(self, vert_a, vert_b, vert_c):
        """Aggregate ``vert_b`` / ``vert_c`` patch features onto ``vert_a`` nodes.

        Args:
            vert_a: tensor of shape (B, va_in, H, W).
            vert_b: tensor with ``vb_in`` channels whose ``ratio[0]`` unfold
                yields exactly H*W patches — presumably
                (B, vb_in, H*ratio[0], W*ratio[0]).
            vert_c: likewise for ``vc_in`` channels and ``ratio[1]``.

        Returns:
            Tensor of shape (B, va_in, H, W).
        """
        n_b = self.ratio[0] * self.ratio[0]  # neighbours per node in vert_b
        n_c = self.ratio[1] * self.ratio[1]  # neighbours per node in vert_c

        # (B, va_out, H*W): one column per a-node.
        emb_vert_a = self.va_embedding(vert_a)
        emb_vert_a = emb_vert_a.reshape(
            [emb_vert_a.shape[0], emb_vert_a.shape[1], -1])

        gate_vert_b = 1 - self.va_gated_b(vert_a)
        gate_vert_b = gate_vert_b.reshape(*emb_vert_a.shape)
        gate_vert_c = 1 - self.va_gated_c(vert_a)
        gate_vert_c = gate_vert_c.reshape(*emb_vert_a.shape)

        # Unfold gives (B, C*k*k, L) with channels outermost; reshape to
        # (B, C, k*k, L) then move C last -> (B, k*k, L, C) for nn.Linear.
        vert_b = self.unfold_b(vert_b).reshape(
            [vert_b.shape[0], vert_b.shape[1], n_b, -1])
        vert_b = vert_b.permute([0, 2, 3, 1])
        emb_vert_b = self.vb_embedding(vert_b)

        vert_c = self.unfold_c(vert_c).reshape(
            [vert_c.shape[0], vert_c.shape[1], n_c, -1])
        vert_c = vert_c.permute([0, 2, 3, 1])
        emb_vert_c = self.vc_embedding(vert_c)

        agg_vb = list()
        agg_vc = list()
        for j in range(emb_vert_a.shape[-1]):
            # (B, 1, va_out): node j's embedding, broadcast to its neighbours.
            emb_a_j = emb_vert_a[:, :, j].unsqueeze(1)

            # ab propagating
            emb_v_b = emb_vert_b[:, :, j, :]  # (B, n_b, vb_out)
            emb_v_ab = torch.cat([emb_a_j.expand(-1, n_b, -1), emb_v_b],
                                 dim=-1)
            w = self.reweight_ab(emb_v_ab)  # (B, n_b, 1)
            # bmm -> (B, vb_out, 1). Squeeze ONLY the trailing dim: a bare
            # squeeze() would also drop B when B == 1 or vb_out when it is 1.
            agg_vb.append(torch.bmm(emb_v_b.transpose(1, 2), w).squeeze(-1)
                          * gate_vert_b[:, :, j])

            # ac propagating
            emb_v_c = emb_vert_c[:, :, j, :]  # (B, n_c, vc_out)
            emb_v_ac = torch.cat([emb_a_j.expand(-1, n_c, -1), emb_v_c],
                                 dim=-1)
            w = self.reweight_ac(emb_v_ac)  # (B, n_c, 1)
            agg_vc.append(torch.bmm(emb_v_c.transpose(1, 2), w).squeeze(-1)
                          * gate_vert_c[:, :, j])

        agg_vert_b = torch.stack(agg_vb, dim=-1)
        agg_vert_c = torch.stack(agg_vc, dim=-1)
        agg_vert_bc = torch.cat([agg_vert_b, agg_vert_c], dim=1)
        agg_vert_abc = torch.cat([agg_vert_bc, emb_vert_a], dim=1)
        agg_vert_abc = torch.sigmoid(agg_vert_abc)
        agg_vert_abc = agg_vert_abc.reshape(vert_a.shape[0], -1,
                                            vert_a.shape[2], vert_a.shape[3])
        return self.reproject(agg_vert_abc)


class GuidedAttention(nn.Module):
    """Reconstruction Guided Attention.

    Turns the absolute reconstruction residual ``|x - pred_x|`` into a
    single-channel map in (0, 1) and uses it to weight a transformed copy of
    ``embedding``, added to a dropout-regularised copy of ``embedding``.

    Args:
        depth: channel count of ``embedding``.
        drop_rate: dropout probability for the residual ``embedding`` path.
    """

    def __init__(self, depth=728, drop_rate=0.2):
        super(GuidedAttention, self).__init__()
        self.depth = depth
        # Expects 3-channel x / pred_x (first conv has 3 input channels).
        self.gated = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(3, 1, 1, bias=False),
            nn.Sigmoid()
        )
        self.h = nn.Sequential(
            nn.Conv2d(depth, depth, 1, 1, bias=False),
            nn.BatchNorm2d(depth),
            nn.ReLU(True),
        )
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x, pred_x, embedding):
        """Return ``gated(resized |x - pred_x|) * h(embedding) + dropout(embedding)``.

        Args:
            x: 3-channel input tensor (B, 3, H, W).
            pred_x: tensor of the same shape as ``x``.
            embedding: tensor of shape (B, depth, h, w).

        Returns:
            Tensor with the shape of ``embedding``.
        """
        residual_full = torch.abs(x - pred_x)
        # Resize the residual to the embedding's spatial size.
        residual_x = F.interpolate(residual_full, size=embedding.shape[-2:],
                                   mode='bilinear', align_corners=True)
        res_map = self.gated(residual_x)
        return res_map * self.h(embedding) + self.dropout(embedding)