m3face's picture
Adding files
332190f
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Qingping Zheng
@Contact : [email protected]
@File : ddgcn.py
@Time : 10/01/21 00:00 PM
@Desc :
@License : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn.functional as F
import torch.nn as nn
from inplace_abn import InPlaceABNSync
class SpatialGCN(nn.Module):
    """Graph-convolution block applied to a 4-D feature map.

    The input is projected by three 1x1 convolutions to ``plane // 2``
    channels (key, value and query).  A channel-by-channel affinity matrix
    is built from query and value, soft-maxed, and used to re-weight the
    key features.  The result goes through a 1-D conv + batch-norm, is
    projected back to ``plane`` channels and added to the input, scaled by
    the learnable scalar ``gamma``.

    Args:
        plane: number of channels of the input (and output) feature map.
        abn: callable building the normalisation layer that follows the
            output projection; it is called as ``abn(plane)``.
    """

    def __init__(self, plane, abn=InPlaceABNSync):
        super(SpatialGCN, self).__init__()
        reduced = plane // 2

        # 1x1 projections producing key / value / query features.
        self.node_k = nn.Conv2d(plane, reduced, kernel_size=1)
        self.node_v = nn.Conv2d(plane, reduced, kernel_size=1)
        self.node_q = nn.Conv2d(plane, reduced, kernel_size=1)

        # Weight update applied to the aggregated features.
        self.conv_wg = nn.Conv1d(reduced, reduced, kernel_size=1, bias=False)
        self.bn_wg = nn.BatchNorm1d(reduced)
        self.softmax = nn.Softmax(dim=2)

        # Projection back to the input channel count.
        self.out = nn.Sequential(
            nn.Conv2d(reduced, plane, kernel_size=1),
            abn(plane),
        )

        # Residual scale; starts at zero.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * out(gcn(x)) + x``; the output has the shape of ``x``."""
        key = self.node_k(x)
        value = self.node_v(x)
        query = self.node_q(x)

        batch, channels, height, _ = key.size()

        key = key.view(batch, channels, -1).transpose(1, 2)      # (b, hw, c)
        query = query.view(batch, channels, -1)                  # (b, c, hw)
        value = value.view(batch, channels, -1).transpose(1, 2)  # (b, hw, c)

        # (b, c, hw) x (b, hw, c) -> (b, c, c), normalised over the last axis.
        affinity = self.softmax(torch.bmm(query, value))

        # (b, hw, c) x (b, c, c) -> (b, hw, c), then back to channels-first.
        aggregated = torch.bmm(key, affinity).transpose(1, 2).contiguous()

        updated = self.bn_wg(self.conv_wg(aggregated))
        updated = updated.view(batch, channels, height, -1)

        return self.gamma * self.out(updated) + x
class DDualGCN(nn.Module):
    """Feature GCN combined with a coordinate (spatial) GCN.

    Two branches are computed from the same input and fused by a 1x1 conv:

    * a local branch that downsamples the input with three stride-2
      depth-wise convolutions, applies ``SpatialGCN``, upsamples back to the
      input resolution and uses the result to modulate the input;
    * a feature branch that projects the input into a small node space,
      applies 1-D convolutions there and re-projects to the input shape.

    Args:
        planes: number of channels of the input (and output) feature map.
        abn: callable building a normalisation layer from a channel count.
        ratio: channel reduction factor; the node space has
            ``planes // ratio`` channels and the state space twice that.
    """

    def __init__(self, planes, abn=InPlaceABNSync, ratio=4):
        super(DDualGCN, self).__init__()
        node_ch = planes // ratio
        state_ch = planes // ratio * 2

        self.phi = nn.Conv2d(planes, state_ch, kernel_size=1, bias=False)
        self.bn_phi = abn(state_ch)
        self.theta = nn.Conv2d(planes, node_ch, kernel_size=1, bias=False)
        self.bn_theta = abn(node_ch)

        # Interaction space.
        # Adjacency matrix: (-)A_g
        self.conv_adj = nn.Conv1d(node_ch, node_ch, kernel_size=1, bias=False)
        self.bn_adj = nn.BatchNorm1d(node_ch)

        # State update function: W_g
        self.conv_wg = nn.Conv1d(state_ch, state_ch, kernel_size=1, bias=False)
        self.bn_wg = nn.BatchNorm1d(state_ch)

        # Last projection back to `planes` channels.
        self.conv3 = nn.Conv2d(state_ch, planes, kernel_size=1, bias=False)
        self.bn3 = abn(planes)

        # Three (depth-wise stride-2 conv, norm) stages.
        local_stages = []
        for _ in range(3):
            local_stages.append(
                nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False)
            )
            local_stages.append(abn(planes))
        self.local = nn.Sequential(*local_stages)
        self.gcn_local_attention = SpatialGCN(planes, abn)

        self.final = nn.Sequential(
            nn.Conv2d(planes * 2, planes, kernel_size=1, bias=False),
            abn(planes),
        )

        # Residual scale of the feature branch; starts at zero.
        self.gamma1 = nn.Parameter(torch.zeros(1))

    def to_matrix(self, x):
        """Flatten the two trailing dims of a 4-D tensor: (n, c, h, w) -> (n, c, h*w)."""
        batch, channels, _, _ = x.size()
        return x.view(batch, channels, -1)

    def forward(self, feat):
        """Fuse the local and feature branches; the output has ``planes`` channels."""
        x = feat

        # ---- Local branch ----
        local = self.gcn_local_attention(self.local(feat))
        local = F.interpolate(local, size=x.size()[2:], mode='bilinear', align_corners=True)
        spatial_local_feat = x * local + x

        # ---- Projection space ----
        states = self.to_matrix(self.bn_phi(self.phi(x)))      # (n, state_ch, h*w)
        nodes = self.to_matrix(self.bn_theta(self.theta(x)))   # (n, node_ch, h*w)

        # Project: (n, state_ch, h*w) x (n, h*w, node_ch) -> (n, state_ch, node_ch)
        z_idt = torch.matmul(states, nodes.transpose(1, 2))

        # ---- Interaction space ----
        z = z_idt.transpose(1, 2).contiguous()
        z = self.bn_adj(self.conv_adj(z))
        z = z.transpose(1, 2).contiguous()
        # Described upstream as Laplacian smoothing, (I - A_g)Z; implemented
        # here as an in-place residual add of the projected features.
        z.add_(z_idt)
        z = self.bn_wg(self.conv_wg(z))

        # ---- Re-projection space ----
        # (n, state_ch, node_ch) x (n, node_ch, h*w) -> (n, state_ch, h*w)
        y = torch.matmul(z, nodes)
        batch, _, height, width = x.size()
        y = y.view(batch, -1, height, width)
        y = self.bn3(self.conv3(y))

        g_out = self.gamma1 * y + x

        # Concatenate both branches along channels and fuse.
        return self.final(torch.cat((spatial_local_feat, g_out), 1))
class DDualGCNHead(nn.Module):
    """Head wrapping ``DDualGCN`` between two 3x3 conv + norm stages.

    The input is reduced to ``interplanes`` channels, passed through
    ``DDualGCN`` and a second 3x3 conv, then concatenated with the original
    input and fused by a final 3x3 bottleneck convolution.

    Args:
        inplanes: number of channels of the input feature map.
        interplanes: number of channels used internally and in the output.
        abn: callable building a normalisation layer from a channel count.
    """

    def __init__(self, inplanes, interplanes, abn=InPlaceABNSync):
        super(DDualGCNHead, self).__init__()

        # Entry stage: inplanes -> interplanes.
        self.conva = nn.Sequential(
            nn.Conv2d(inplanes, interplanes, 3, padding=1, bias=False),
            abn(interplanes),
        )
        self.dualgcn = DDualGCN(interplanes, abn)

        # Exit stage: interplanes -> interplanes.
        self.convb = nn.Sequential(
            nn.Conv2d(interplanes, interplanes, 3, padding=1, bias=False),
            abn(interplanes),
        )

        # Fuses the concatenation of the raw input and the GCN features.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(inplanes + interplanes, interplanes, kernel_size=3, padding=1, dilation=1, bias=False),
            abn(interplanes),
        )

    def forward(self, x):
        """Return the fused features, with ``interplanes`` channels."""
        gcn_feat = self.convb(self.dualgcn(self.conva(x)))
        return self.bottleneck(torch.cat([x, gcn_feat], 1))