# VideoToNPZ/model/global_attention.py
from __future__ import absolute_import, division
import torch
from torch import nn


class GlobalGraph(nn.Module):
""""
Global graph attention layer
"""
def __init__(self, adj, in_channels, inter_channels=None):
super(GlobalGraph, self).__init__()
self.adj = adj
self.in_channels = in_channels
self.inter_channels = inter_channels
self.softmax = nn.Softmax(dim=-1)
self.relu = nn.ReLU(inplace=True)
self.leakyrelu = nn.LeakyReLU(0.2)
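        # Width of the value projection g: when inter_channels is exactly
        # in_channels // 2 the values keep the full input width (so the layer
        # preserves in_channels); otherwise they are reduced to inter_channels.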
if self.inter_channels == self.in_channels // 2:
self.g_channels = self.in_channels
else:
self.g_channels = self.inter_channels
        assert self.inter_channels is not None and self.inter_channels > 0
self.g = nn.Conv1d(in_channels=self.in_channels, out_channels=self.g_channels,
kernel_size=1, stride=1, padding=0)
self.theta = nn.Conv1d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
self.phi = nn.Conv1d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
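        # C_k: a learnable bias with the same shape as adj, added on top of
        # the softmax attention map. adj itself is only used for its shape.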
adj_shape = self.adj.shape
self.C_k = nn.Parameter(torch.zeros(adj_shape, dtype=torch.float))
self.concat_project = nn.Sequential(
nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
)
nn.init.kaiming_normal_(self.concat_project[0].weight)
nn.init.kaiming_normal_(self.g.weight)
nn.init.constant_(self.g.bias, 0)
nn.init.kaiming_normal_(self.theta.weight)
nn.init.constant_(self.theta.bias, 0)
nn.init.kaiming_normal_(self.phi.weight)
nn.init.constant_(self.phi.bias, 0)

    def forward(self, x):
batch_size = x.size(0) # x: (B*T, C, N)
# g_x: (B*T, N, C/k)
g_x = self.g(x).view(batch_size, self.g_channels, -1)
g_x = g_x.permute(0, 2, 1)
# (B*T, C/k, N, 1)
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
# (B*T, C/k, 1, N)
phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
# h: N, w: N
h = theta_x.size(2)
w = phi_x.size(3)
theta_x = theta_x.expand(-1, -1, -1, w) # (B*T, C/k, N, N)
phi_x = phi_x.expand(-1, -1, h, -1)
# concat_feature: (B*T, C/k, N, N)
concat_feature = torch.cat([theta_x, phi_x], dim=1)
f = self.concat_project(concat_feature) # (B*T, 1, N, N)
b, _, h, w = f.size()
        attention = self.leakyrelu(f.view(b, h, w))  # (B*T, N, N), the B_k attention map
attention = torch.add(self.softmax(attention), self.C_k)
        # y: (B*T, N, C/k) -> (B*T, C/k, N)
y = torch.matmul(attention, g_x)
y = y.permute(0, 2, 1).contiguous()
y = y.view(batch_size, self.g_channels, *x.size()[2:])
return y
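

# A minimal usage sketch for GlobalGraph, not part of the original module.
# All sizes here are assumptions for illustration: N = 17 joints
# (a Human3.6M-style skeleton), C = 128 channels, and an identity matrix
# standing in for the real adjacency (only its shape is used, to size C_k).
def _demo_global_graph():
    n_joints, channels = 17, 128
    adj = torch.eye(n_joints)
    layer = GlobalGraph(adj, in_channels=channels, inter_channels=channels // 4)
    x = torch.randn(2, channels, n_joints)          # (B*T, C, N)
    y = layer(x)                                    # (B*T, C/4, N)
    assert y.shape == (2, channels // 4, n_joints)  # g_channels == inter_channels here
    return y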


class MultiGlobalGraph(nn.Module):
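    """
    Multi-head global attention: in_channels // inter_channels parallel
    GlobalGraph heads, concatenated along the channel dimension and fused
    with a 1x1 conv + BatchNorm + ReLU (and optional dropout).
    """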
def __init__(self, adj, in_channels, inter_channels, dropout=None):
super(MultiGlobalGraph, self).__init__()
self.num_non_local = in_channels // inter_channels
attentions = [GlobalGraph(adj, in_channels, inter_channels) for _ in range(self.num_non_local)]
self.attentions = nn.ModuleList(attentions)
self.cat_conv = nn.Conv2d(in_channels, in_channels, 1, bias=False)
self.cat_bn = nn.BatchNorm2d(in_channels, momentum=0.1)
self.relu = nn.ReLU(inplace=True)
if dropout is not None:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None

    def forward(self, x):
# x: (B, T, K, C) --> (B*T, K, C)
x_size = x.shape
x = x.contiguous()
x = x.view(-1, *x_size[2:])
# x: (B*T, C, K)
x = x.permute(0, 2, 1)
        x = torch.cat([attention(x) for attention in self.attentions], dim=1)
# x: (B*T, C, K) --> (B*T, K, C)
x = x.permute(0, 2, 1).contiguous()
# x: (B*T, K, C) --> (B, T, K, C)
x = x.view(*x_size)
# x: (B, T, K, C) --> (B, C, T, K)
x = x.permute(0, 3, 1, 2)
x = self.relu(self.cat_bn(self.cat_conv(x)))
if self.dropout is not None:
x = self.dropout(x)
# x: (B, C, T, K) --> (B, T, K, C)
x = x.permute(0, 2, 3, 1)
return x


class SingleGlobalGraph(nn.Module):
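    """
    Single-head global attention: one GlobalGraph followed by
    BatchNorm + ReLU (and optional dropout).
    """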
def __init__(self, adj, in_channels, output_channels, dropout=None):
super(SingleGlobalGraph, self).__init__()
        self.attentions = GlobalGraph(adj, in_channels, output_channels // 2)
self.bn = nn.BatchNorm2d(in_channels, momentum=0.1)
self.relu = nn.ReLU(inplace=True)
if dropout is not None:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None

    def forward(self, x):
# x: (B, T, K, C) --> (B*T, K, C)
x_size = x.shape
x = x.contiguous()
x = x.view(-1, *x_size[2:])
# x: (B*T, C, K)
x = x.permute(0, 2, 1)
x = self.attentions(x)
# x: (B*T, C, K) --> (B*T, K, C)
x = x.permute(0, 2, 1).contiguous()
# x: (B*T, K, C) --> (B, T, K, C)
x = x.view(*x_size)
# x: (B, T, K, C) --> (B, C, T, K)
x = x.permute(0, 3, 1, 2)
x = self.relu(self.bn(x))
if self.dropout is not None:
x = self.dropout(x)
# x: (B, C, T, K) --> (B, T, K, C)
x = x.permute(0, 2, 3, 1)
return x
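

# A minimal smoke test, assuming the (B, T, K, C) input convention used in
# the comments above. The 17-joint skeleton, identity adjacency, and all
# sizes below are illustrative placeholders, not values from the project.
if __name__ == "__main__":
    n_joints, channels = 17, 128
    adj = torch.eye(n_joints)
    x = torch.randn(2, 4, n_joints, channels)  # (B, T, K, C)

    # 128 // 32 = 4 heads of width 32, concatenated back to 128 channels.
    multi = MultiGlobalGraph(adj, in_channels=channels, inter_channels=32, dropout=0.1)
    print(multi(x).shape)   # torch.Size([2, 4, 17, 128])

    # With output_channels == in_channels, inter_channels == in_channels // 2,
    # so the value branch keeps full width and the channel count is preserved.
    single = SingleGlobalGraph(adj, in_channels=channels, output_channels=channels)
    print(single(x).shape)  # torch.Size([2, 4, 17, 128])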