import torch
import torch.nn as nn
import torch.nn.functional as F
def freeze_weights(module):
    """ Freeze all parameters of a module so they are excluded from gradient updates. """
    for param in module.parameters():
        param.requires_grad = False
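
# A minimal usage sketch (assumption: `model.backbone` is a pretrained encoder
# that should stay fixed during fine-tuning; only unfrozen parameters are
# handed to the optimizer):
#
#   freeze_weights(model.backbone)
#   optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
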
def l1_regularize(module):
    """ L1 penalty: sum of absolute values of the registered weight parameters. """
    reg_loss = 0.
    # `reg_params` is expected to be a mapping of parameter name -> tensor
    # attached to the module by the training code; only trainable weights count.
    for key, param in module.reg_params.items():
        if "weight" in key and param.requires_grad:
            reg_loss += torch.sum(torch.abs(param))
    return reg_loss
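
# A minimal usage sketch (assumption: how `reg_params` is populated is up to
# the training code; registering all named parameters is one plausible choice):
#
#   model.reg_params = dict(model.named_parameters())
#   loss = task_loss + 1e-5 * l1_regularize(model)
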
class SeparableConv2d(nn.Module):
    """ Depthwise-separable convolution: per-channel spatial conv followed by a 1x1 pointwise conv. """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        # Depthwise: groups=in_channels applies one spatial filter per channel.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation,
                               groups=in_channels, bias=bias)
        # Pointwise: 1x1 conv mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x
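
# A minimal shape check (illustrative values, not from the original config):
#
#   sep = SeparableConv2d(64, 128, kernel_size=3, padding=1)
#   out = sep(torch.randn(2, 64, 56, 56))  # -> torch.Size([2, 128, 56, 56])
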
class Block(nn.Module):
    """ Xception-style residual block built from separable convolutions. """

    def __init__(self, in_channels, out_channels, reps, strides=1,
                 start_with_relu=True, grow_first=True, with_bn=True):
        super(Block, self).__init__()
        self.with_bn = with_bn

        # 1x1 projection on the skip path whenever the channel count or stride changes.
        if out_channels != in_channels or strides != 1:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False)
            if with_bn:
                self.skipbn = nn.BatchNorm2d(out_channels)
        else:
            self.skip = None

        rep = []
        for i in range(reps):
            if grow_first:
                # Widen channels in the first separable conv.
                inc = in_channels if i == 0 else out_channels
                outc = out_channels
            else:
                # Keep channels until the last separable conv, then widen.
                inc = in_channels
                outc = in_channels if i < (reps - 1) else out_channels
            rep.append(nn.ReLU(inplace=True))
            rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1))
            if with_bn:
                rep.append(nn.BatchNorm2d(outc))
        if not start_with_relu:
            rep = rep[1:]
        else:
            # The leading ReLU must not be in-place: it operates on the block input,
            # which the skip path still needs.
            rep[0] = nn.ReLU(inplace=False)
        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            if self.with_bn:
                skip = self.skipbn(skip)
        else:
            skip = inp
        x += skip
        return x
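
# A minimal shape check (illustrative values): a downsampling Block doubles the
# channels and halves the spatial size; the 1x1 skip conv keeps the residual
# addition shape-compatible.
#
#   blk = Block(128, 256, reps=2, strides=2)
#   out = blk(torch.randn(2, 128, 56, 56))  # -> torch.Size([2, 256, 28, 28])
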
class GraphReasoning(nn.Module):
    """ Graph Reasoning Module for information aggregation. """

    def __init__(self, va_in, va_out, vb_in, vb_out, vc_in, vc_out, spatial_ratio, drop_rate):
        super(GraphReasoning, self).__init__()
        self.ratio = spatial_ratio
        self.va_embedding = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(va_out, va_out, 1, bias=False),
        )
        # Gates that control how much of each neighbor's message reaches vert_a.
        self.va_gated_b = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.Sigmoid()
        )
        self.va_gated_c = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.Sigmoid()
        )
        self.vb_embedding = nn.Sequential(
            nn.Linear(vb_in, vb_out, bias=False),
            nn.ReLU(True),
            nn.Linear(vb_out, vb_out, bias=False),
        )
        self.vc_embedding = nn.Sequential(
            nn.Linear(vc_in, vc_out, bias=False),
            nn.ReLU(True),
            nn.Linear(vc_out, vc_out, bias=False),
        )
        # Non-overlapping patches: one patch of vert_b/vert_c vertices per vert_a location.
        self.unfold_b = nn.Unfold(kernel_size=spatial_ratio[0], stride=spatial_ratio[0])
        self.unfold_c = nn.Unfold(kernel_size=spatial_ratio[1], stride=spatial_ratio[1])
        # Attention weights over the vertices inside each patch (softmax over dim=1).
        self.reweight_ab = nn.Sequential(
            nn.Linear(va_out + vb_out, 1, bias=False),
            nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        self.reweight_ac = nn.Sequential(
            nn.Linear(va_out + vc_out, 1, bias=False),
            nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        self.reproject = nn.Sequential(
            nn.Conv2d(va_out + vb_out + vc_out, va_in, kernel_size=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(va_in, va_in, kernel_size=1, bias=False),
            nn.Dropout(drop_rate) if drop_rate is not None else nn.Identity(),
        )

    def forward(self, vert_a, vert_b, vert_c):
        emb_vert_a = self.va_embedding(vert_a)
        emb_vert_a = emb_vert_a.reshape([emb_vert_a.shape[0], emb_vert_a.shape[1], -1])

        gate_vert_b = 1 - self.va_gated_b(vert_a)
        gate_vert_b = gate_vert_b.reshape(*emb_vert_a.shape)
        gate_vert_c = 1 - self.va_gated_c(vert_a)
        gate_vert_c = gate_vert_c.reshape(*emb_vert_a.shape)

        # Split the larger maps into patch vertices aligned with vert_a locations.
        vert_b = self.unfold_b(vert_b).reshape(
            [vert_b.shape[0], vert_b.shape[1], self.ratio[0] * self.ratio[0], -1])
        vert_b = vert_b.permute([0, 2, 3, 1])
        emb_vert_b = self.vb_embedding(vert_b)

        vert_c = self.unfold_c(vert_c).reshape(
            [vert_c.shape[0], vert_c.shape[1], self.ratio[1] * self.ratio[1], -1])
        vert_c = vert_c.permute([0, 2, 3, 1])
        emb_vert_c = self.vc_embedding(vert_c)

        agg_vb = list()
        agg_vc = list()
        for j in range(emb_vert_a.shape[-1]):
            # b -> a propagation: attention-weighted sum of the patch vertices.
            emb_v_a = torch.stack([emb_vert_a[:, :, j]] * (self.ratio[0] ** 2), dim=1)
            emb_v_b = emb_vert_b[:, :, j, :]
            emb_v_ab = torch.cat([emb_v_a, emb_v_b], dim=-1)
            w = self.reweight_ab(emb_v_ab)
            # squeeze(-1) rather than a bare squeeze(): the latter would also drop
            # the batch dimension when the batch size is 1.
            agg_vb.append(torch.bmm(emb_v_b.transpose(1, 2), w).squeeze(-1) * gate_vert_b[:, :, j])
            # c -> a propagation.
            emb_v_a = torch.stack([emb_vert_a[:, :, j]] * (self.ratio[1] ** 2), dim=1)
            emb_v_c = emb_vert_c[:, :, j, :]
            emb_v_ac = torch.cat([emb_v_a, emb_v_c], dim=-1)
            w = self.reweight_ac(emb_v_ac)
            agg_vc.append(torch.bmm(emb_v_c.transpose(1, 2), w).squeeze(-1) * gate_vert_c[:, :, j])

        agg_vert_b = torch.stack(agg_vb, dim=-1)
        agg_vert_c = torch.stack(agg_vc, dim=-1)
        agg_vert_bc = torch.cat([agg_vert_b, agg_vert_c], dim=1)
        agg_vert_abc = torch.cat([agg_vert_bc, emb_vert_a], dim=1)
        agg_vert_abc = torch.sigmoid(agg_vert_abc)
        agg_vert_abc = agg_vert_abc.reshape(vert_a.shape[0], -1, vert_a.shape[2], vert_a.shape[3])
        return self.reproject(agg_vert_abc)
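
# A minimal shape sketch (assumed configuration, not necessarily the original
# training setup): vert_b and vert_c must be spatial_ratio[0]x and
# spatial_ratio[1]x larger than vert_a so unfold() yields one patch per vert_a
# location, and the gating multiply requires vb_out == vc_out == va_out.
#
#   gr = GraphReasoning(va_in=728, va_out=128, vb_in=256, vb_out=128,
#                       vc_in=256, vc_out=128, spatial_ratio=[2, 4], drop_rate=0.2)
#   va = torch.randn(2, 728, 8, 8)
#   vb = torch.randn(2, 256, 16, 16)   # 2x the spatial size of va
#   vc = torch.randn(2, 256, 32, 32)   # 4x the spatial size of va
#   out = gr(va, vb, vc)               # -> torch.Size([2, 728, 8, 8])
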
class GuidedAttention(nn.Module):
    """ Reconstruction Guided Attention. """

    def __init__(self, depth=728, drop_rate=0.2):
        super(GuidedAttention, self).__init__()
        self.depth = depth
        # Turn the 3-channel reconstruction residual into a single-channel gate.
        self.gated = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(3, 1, 1, bias=False),
            nn.Sigmoid()
        )
        self.h = nn.Sequential(
            nn.Conv2d(depth, depth, 1, 1, bias=False),
            nn.BatchNorm2d(depth),
            nn.ReLU(True),
        )
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x, pred_x, embedding):
        # Reconstruction error at input resolution, downsampled to the embedding size.
        residual_full = torch.abs(x - pred_x)
        residual_x = F.interpolate(residual_full, size=embedding.shape[-2:],
                                   mode='bilinear', align_corners=True)
        res_map = self.gated(residual_x)
        # Residual-guided reweighting plus a dropout-regularized identity path.
        return res_map * self.h(embedding) + self.dropout(embedding)
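
# A minimal smoke test, runnable as a script (the shapes are assumptions chosen
# to be mutually consistent; the original training configuration may differ):
if __name__ == "__main__":
    ga = GuidedAttention(depth=728, drop_rate=0.2)
    x = torch.randn(2, 3, 299, 299)        # input image batch
    pred_x = torch.randn(2, 3, 299, 299)   # reconstructed image batch
    emb = torch.randn(2, 728, 19, 19)      # mid-level feature map
    out = ga(x, pred_x, emb)
    print(out.shape)                       # torch.Size([2, 728, 19, 19])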