|
|
|
|
|
""" |
|
@Author : Qingping Zheng |
|
@Contact : [email protected] |
|
@File : ddgcn.py |
|
@Time : 10/01/21 00:00 PM |
|
@Desc : |
|
@License : Licensed under the Apache License, Version 2.0 (the "License"); |
|
@Copyright : Copyright 2022 The Authors. All Rights Reserved. |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
|
|
from inplace_abn import InPlaceABNSync |
|
|
|
|
|
class SpatialGCN(nn.Module):
    """Residual graph-convolution block driven by a channel affinity matrix.

    The input is projected to ``plane // 2`` channels by three independent
    1x1 convolutions.  Two of the projections are multiplied into a
    ``(c, c)`` affinity matrix that is soft-maxed along its last axis and
    then used to re-mix the third projection.  The result goes through a
    1-D convolution with batch norm, is mapped back to ``plane`` channels
    and is added to the input, scaled by the learnable scalar ``gamma``.

    Args:
        plane: Channel count of the input (and output) feature map.
        abn: Callable that builds a normalisation layer from a channel
            count; defaults to ``InPlaceABNSync``.
    """

    def __init__(self, plane, abn=InPlaceABNSync):
        super(SpatialGCN, self).__init__()
        reduced = plane // 2

        # Three 1x1 projections of the input.  Creation order is kept as
        # k, v, q so parameter ordering and seeded initialisation match.
        self.node_k = nn.Conv2d(plane, reduced, kernel_size=1)
        self.node_v = nn.Conv2d(plane, reduced, kernel_size=1)
        self.node_q = nn.Conv2d(plane, reduced, kernel_size=1)

        # Point-wise 1-D convolution + batch norm applied after the mixing.
        self.conv_wg = nn.Conv1d(reduced, reduced, kernel_size=1, bias=False)
        self.bn_wg = nn.BatchNorm1d(reduced)
        self.softmax = nn.Softmax(dim=2)

        # Projection back to the input width.
        self.out = nn.Sequential(
            nn.Conv2d(reduced, plane, kernel_size=1),
            abn(plane),
        )

        # Residual scale; it starts at zero.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x: 4-D feature map laid out as ``(batch, plane, height, width)``.

        Returns:
            Tensor with the same shape as ``x``: the input plus the
            ``gamma``-scaled graph-convolution branch.
        """
        key = self.node_k(x)
        value = self.node_v(x)
        query = self.node_q(x)

        batch, channels, height, _ = key.size()

        # Collapse the spatial axes: key/value -> (b, hw, c), query -> (b, c, hw).
        key = key.view(batch, channels, -1).transpose(1, 2)
        query = query.view(batch, channels, -1)
        value = value.view(batch, channels, -1).transpose(1, 2)

        # (b, c, hw) x (b, hw, c) -> (b, c, c), normalised over the last axis.
        affinity = self.softmax(query.bmm(value))

        # (b, hw, c) x (b, c, c) -> (b, hw, c); move channels back to axis 1
        # because Conv1d / BatchNorm1d expect (b, c, length).
        mixed = key.bmm(affinity).transpose(1, 2).contiguous()
        mixed = self.bn_wg(self.conv_wg(mixed))

        # Restore the 2-D spatial layout before the output projection.
        mixed = mixed.view(batch, channels, height, -1)

        branch = self.out(mixed)
        return self.gamma * branch + x
|
|
|
|
|
class DDualGCN(nn.Module):
    """Feature GCN combined with a coordinate GCN.

    Two branches are computed from the same input and fused by a 1x1
    convolution:

    * a *local* branch that down-samples the input with three stride-2
      depth-wise convolutions, applies :class:`SpatialGCN`, up-samples the
      result back to the input resolution and uses it to modulate the input
      (``x * local + x``);
    * a *feature* branch that projects the input to two reduced channel
      widths, builds an affinity matrix between them, passes it through two
      1-D convolutions with batch norm, projects back onto the spatial grid
      and adds the result to the input, scaled by ``gamma1``.

    Args:
        planes: Channel count of the input (and output) feature map.
        abn: Callable that builds a normalisation layer from a channel
            count; defaults to ``InPlaceABNSync``.
        ratio: Channel-reduction factor; the feature branch works with
            ``planes // ratio`` and ``planes // ratio * 2`` channels.
    """

    def __init__(self, planes, abn=InPlaceABNSync, ratio=4):
        super(DDualGCN, self).__init__()

        node_width = planes // ratio
        state_width = node_width * 2

        # Reduced projections used by the feature branch.
        self.phi = nn.Conv2d(planes, state_width, kernel_size=1, bias=False)
        self.bn_phi = abn(state_width)
        self.theta = nn.Conv2d(planes, node_width, kernel_size=1, bias=False)
        self.bn_theta = abn(node_width)

        # 1-D convolution + batch norm over the transposed affinity matrix.
        self.conv_adj = nn.Conv1d(node_width, node_width, kernel_size=1, bias=False)
        self.bn_adj = nn.BatchNorm1d(node_width)

        # 1-D convolution + batch norm over the affinity matrix itself.
        self.conv_wg = nn.Conv1d(state_width, state_width, kernel_size=1, bias=False)
        self.bn_wg = nn.BatchNorm1d(state_width)

        # Projection of the feature branch back to ``planes`` channels.
        self.conv3 = nn.Conv2d(state_width, planes, kernel_size=1, bias=False)
        self.bn3 = abn(planes)

        # Local branch: three (depth-wise 3x3, stride 2) + norm stages,
        # created in the same conv/norm order as a literal listing would be.
        stages = []
        for _ in range(3):
            stages.append(
                nn.Conv2d(planes, planes, kernel_size=3, groups=planes,
                          stride=2, padding=1, bias=False)
            )
            stages.append(abn(planes))
        self.local = nn.Sequential(*stages)
        self.gcn_local_attention = SpatialGCN(planes, abn)

        # Fusion of the two concatenated branches.
        self.final = nn.Sequential(
            nn.Conv2d(planes * 2, planes, kernel_size=1, bias=False),
            abn(planes),
        )

        # Residual scale of the feature branch; it starts at zero.
        self.gamma1 = nn.Parameter(torch.zeros(1))

    def to_matrix(self, x):
        """Flatten the two trailing axes of a 4-D tensor.

        Args:
            x: Tensor of shape ``(n, c, h, w)``.

        Returns:
            A view of ``x`` with shape ``(n, c, h * w)``.
        """
        # The four-way unpack doubles as a check that ``x`` is 4-D.
        batch, channels, _, _ = x.size()
        return x.view(batch, channels, -1)

    def forward(self, feat):
        """Fuse the local and feature branches computed from ``feat``.

        Args:
            feat: 4-D feature map laid out as ``(n, planes, h, w)``.

        Returns:
            Output of the ``final`` 1x1 convolution + normalisation applied
            to the channel-wise concatenation of both branches.
        """
        x = feat

        # ----- local branch ------------------------------------------------
        coarse = self.local(feat)
        coarse = self.gcn_local_attention(coarse)
        coarse = F.interpolate(
            coarse, size=x.shape[2:], mode='bilinear', align_corners=True
        )
        spatial_local_feat = x * coarse + x

        # ----- feature branch ----------------------------------------------
        # states: (n, state_width, hw); nodes: (n, node_width, hw)
        states = self.to_matrix(self.bn_phi(self.phi(x)))
        nodes = self.to_matrix(self.bn_theta(self.theta(x)))

        # (n, state_width, hw) x (n, hw, node_width) -> (n, state_width, node_width)
        affinity = torch.matmul(states, nodes.transpose(1, 2))

        # conv_adj / bn_adj treat node_width as the channel axis, so the
        # matrix is transposed going in and coming out.
        adj = affinity.transpose(1, 2).contiguous()
        adj = self.bn_adj(self.conv_adj(adj))
        adj = adj.transpose(1, 2).contiguous()

        # Skip connection around the first convolution (in-place add).
        adj += affinity

        adj = self.bn_wg(self.conv_wg(adj))

        # (n, state_width, node_width) x (n, node_width, hw) -> (n, state_width, hw)
        y = torch.matmul(adj, nodes)

        # Restore the spatial grid of the input.
        batch, _, height, width = x.size()
        y = y.view(batch, -1, height, width)

        y = self.bn3(self.conv3(y))
        g_out = self.gamma1 * y + x

        # ----- fusion ------------------------------------------------------
        fused = torch.cat((spatial_local_feat, g_out), 1)
        return self.final(fused)
|
|
|
|
|
class DDualGCNHead(nn.Module):
    """Head that wraps :class:`DDualGCN` between two 3x3 convolutions.

    The input is reduced to ``interplanes`` channels, passed through
    ``DDualGCN`` and a second 3x3 convolution, then concatenated with the
    original input and fused by a 3x3 bottleneck convolution.

    Args:
        inplanes: Channel count of the input feature map.
        interplanes: Channel count used inside the head and of its output.
        abn: Callable that builds a normalisation layer from a channel
            count; defaults to ``InPlaceABNSync``.
    """

    def __init__(self, inplanes, interplanes, abn=InPlaceABNSync):
        super(DDualGCNHead, self).__init__()

        # Entry convolution: inplanes -> interplanes.
        self.conva = nn.Sequential(
            nn.Conv2d(inplanes, interplanes, kernel_size=3, padding=1, bias=False),
            abn(interplanes),
        )

        self.dualgcn = DDualGCN(interplanes, abn)

        # Exit convolution, width unchanged.
        self.convb = nn.Sequential(
            nn.Conv2d(interplanes, interplanes, kernel_size=3, padding=1, bias=False),
            abn(interplanes),
        )

        # Fuses the raw input with the processed features.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(inplanes + interplanes, interplanes, kernel_size=3,
                      padding=1, dilation=1, bias=False),
            abn(interplanes),
        )

    def forward(self, x):
        """Return the fused features for the 4-D input ``x``.

        Args:
            x: Feature map with ``inplanes`` channels.

        Returns:
            Feature map with ``interplanes`` channels produced by the
            bottleneck from ``cat([x, processed], dim=1)``.
        """
        processed = self.convb(self.dualgcn(self.conva(x)))
        stacked = torch.cat([x, processed], 1)
        return self.bottleneck(stacked)
|
|