Spaces:
Runtime error
Runtime error
from __future__ import absolute_import, division | |
import torch | |
from torch import nn | |
class GlobalGraph(nn.Module):
    """Global graph attention layer.

    Computes a data-dependent (N, N) attention map between all pairs of
    nodes, adds a learnable per-layer offset ``C_k`` to it, and uses the
    result to aggregate projected node features.

    Args:
        adj: Object exposing ``.shape``; only its shape is used, to size the
            learnable offset ``C_k``. The object itself is kept on
            ``self.adj``.
        in_channels: Number of channels of the input features.
        inter_channels: Number of channels of the ``theta``/``phi``
            projections. When ``None`` it defaults to
            ``max(in_channels // 2, 1)``.

    Raises:
        ValueError: If ``inter_channels`` is not strictly positive.

    Note:
        The output has ``in_channels`` channels when
        ``inter_channels == in_channels // 2`` and ``inter_channels``
        channels otherwise (see ``g_channels``).
    """

    def __init__(self, adj, in_channels, inter_channels=None):
        super(GlobalGraph, self).__init__()

        # The previous code compared ``None > 0`` for the default argument,
        # which raises TypeError on Python 3; pick a usable default instead.
        if inter_channels is None:
            inter_channels = max(in_channels // 2, 1)
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if inter_channels <= 0:
            raise ValueError(
                "inter_channels must be positive, got {}".format(inter_channels)
            )

        self.adj = adj
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU(inplace=True)
        self.leakyrelu = nn.LeakyReLU(0.2)

        # The value branch keeps the full channel count when the bottleneck
        # is exactly half of the input, otherwise it uses the bottleneck.
        if self.inter_channels == self.in_channels // 2:
            self.g_channels = self.in_channels
        else:
            self.g_channels = self.inter_channels

        # 1x1 convolutions acting as per-node linear projections.
        self.g = nn.Conv1d(in_channels=self.in_channels, out_channels=self.g_channels,
                           kernel_size=1, stride=1, padding=0)
        self.theta = nn.Conv1d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv1d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        # Learnable offset added to the attention map, initialised to zero.
        self.C_k = nn.Parameter(torch.zeros(self.adj.shape, dtype=torch.float))

        # Maps each concatenated (theta_i, phi_j) pair to one scalar score.
        self.concat_project = nn.Sequential(
            nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
        )

        nn.init.kaiming_normal_(self.concat_project[0].weight)
        nn.init.kaiming_normal_(self.g.weight)
        nn.init.constant_(self.g.bias, 0)
        nn.init.kaiming_normal_(self.theta.weight)
        nn.init.constant_(self.theta.bias, 0)
        nn.init.kaiming_normal_(self.phi.weight)
        nn.init.constant_(self.phi.bias, 0)

    def forward(self, x):
        """Apply global attention over the node axis.

        Args:
            x: Tensor of shape (B*T, C, N) with ``C == in_channels``.

        Returns:
            Tensor of shape (B*T, g_channels, N).
        """
        batch_size = x.size(0)

        # g_x: (B*T, N, g_channels)
        g_x = self.g(x).view(batch_size, self.g_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: (B*T, inter_channels, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # phi_x: (B*T, inter_channels, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        # Broadcast both projections to every (i, j) node pair.
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.expand(-1, -1, -1, w)  # (B*T, inter_channels, N, N)
        phi_x = phi_x.expand(-1, -1, h, -1)      # (B*T, inter_channels, N, N)

        # concat_feature: (B*T, 2 * inter_channels, N, N)
        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)  # (B*T, 1, N, N)
        b, _, h, w = f.size()

        # Row-normalised pairwise scores plus the learnable offset C_k.
        attention = self.leakyrelu(f.view(b, h, w))  # (B*T, N, N)
        attention = torch.add(self.softmax(attention), self.C_k)

        # Aggregate value features: (B*T, N, g_channels) -> (B*T, g_channels, N)
        y = torch.matmul(attention, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.g_channels, *x.size()[2:])
        return y
class MultiGlobalGraph(nn.Module):
    """Several GlobalGraph layers applied in parallel and fused.

    ``in_channels // inter_channels`` GlobalGraph layers are run on the same
    input. Their outputs are concatenated along the channel axis, reshaped
    back to the input shape and passed through a 1x1 convolution, batch
    normalisation, ReLU and (optionally) dropout.

    Args:
        adj: Forwarded unchanged to every GlobalGraph layer.
        in_channels: Number of channels of the input features.
        inter_channels: Bottleneck width handed to each GlobalGraph layer.
        dropout: Dropout probability, or ``None`` to disable dropout.

    Note:
        ``forward`` reshapes the concatenated outputs back to the input
        shape, so the channel counts produced by the GlobalGraph layers must
        add up to the size of the input's last dimension.
    """

    def __init__(self, adj, in_channels, inter_channels, dropout=None):
        super(MultiGlobalGraph, self).__init__()
        self.num_non_local = in_channels // inter_channels

        attentions = [GlobalGraph(adj, in_channels, inter_channels)
                      for _ in range(self.num_non_local)]
        self.attentions = nn.ModuleList(attentions)

        self.cat_conv = nn.Conv2d(in_channels, in_channels, 1, bias=False)
        self.cat_bn = nn.BatchNorm2d(in_channels, momentum=0.1)
        self.relu = nn.ReLU(inplace=True)

        if dropout is not None:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Run every attention layer and fuse their outputs.

        Args:
            x: Tensor of shape (B, T, K, C).

        Returns:
            Tensor of shape (B, T, K, C).
        """
        # x: (B, T, K, C) --> (B*T, K, C)
        x_size = x.shape
        x = x.contiguous()
        x = x.view(-1, *x_size[2:])

        # x: (B*T, K, C) --> (B*T, C, K)
        x = x.permute(0, 2, 1)

        # Every layer sees the same input; outputs are stacked on channels.
        x = torch.cat([attention(x) for attention in self.attentions], dim=1)

        # x: (B*T, C, K) --> (B*T, K, C)
        x = x.permute(0, 2, 1).contiguous()

        # x: (B*T, K, C) --> (B, T, K, C)
        x = x.view(*x_size)

        # x: (B, T, K, C) --> (B, C, T, K)
        x = x.permute(0, 3, 1, 2)
        x = self.relu(self.cat_bn(self.cat_conv(x)))
        if self.dropout is not None:
            x = self.dropout(x)

        # x: (B, C, T, K) --> (B, T, K, C)
        x = x.permute(0, 2, 3, 1)
        return x
class SingleGlobalGraph(nn.Module):
    """One GlobalGraph layer followed by batch norm, ReLU and dropout.

    Args:
        adj: Forwarded unchanged to the GlobalGraph layer.
        in_channels: Number of channels of the input features; also the
            number of features normalised by the batch-norm layer.
        output_channels: Half of this value (integer division) is used as
            the GlobalGraph bottleneck width.
        dropout: Dropout probability, or ``None`` to disable dropout.
    """

    def __init__(self, adj, in_channels, output_channels, dropout=None):
        super(SingleGlobalGraph, self).__init__()
        bottleneck = output_channels // 2
        self.attentions = GlobalGraph(adj, in_channels, bottleneck)
        self.bn = nn.BatchNorm2d(in_channels, momentum=0.1)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(self, x):
        """Apply global attention to a (B, T, K, C) tensor.

        Args:
            x: Tensor of shape (B, T, K, C).

        Returns:
            Tensor of shape (B, T, K, C).
        """
        original_shape = x.shape

        # (B, T, K, C) -> (B*T, K, C) -> (B*T, C, K)
        flat = x.contiguous().view(-1, *original_shape[2:])
        flat = flat.transpose(1, 2)

        attended = self.attentions(flat)

        # (B*T, C, K) -> (B*T, K, C) -> (B, T, K, C)
        attended = attended.transpose(1, 2).contiguous()
        attended = attended.view(*original_shape)

        # Channels-first layout for BatchNorm2d: (B, T, K, C) -> (B, C, T, K)
        out = attended.permute(0, 3, 1, 2)
        out = self.bn(out)
        out = self.relu(out)
        if self.dropout is not None:
            out = self.dropout(out)

        # Restore channels-last layout: (B, C, T, K) -> (B, T, K, C)
        return out.permute(0, 2, 3, 1)