Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.init as init | |
import torch.nn.functional as F | |
from torch.nn.parameter import Parameter | |
class BatchAGC(nn.Module): | |
def __init__(self, in_features, out_features, bias=True): | |
super(BatchAGC, self).__init__() | |
self.in_features = in_features | |
self.out_features = out_features | |
self.weight = Parameter(torch.Tensor(in_features, out_features)) | |
if bias: | |
self.bias = Parameter(torch.Tensor(out_features)) | |
init.constant_(self.bias, 0) | |
else: | |
self.register_parameter('bias', None) | |
init.xavier_uniform_(self.weight) | |
def forward(self, x, adj): | |
expand_weight = self.weight.expand(x.shape[0], -1, -1) | |
support = torch.bmm(x, expand_weight) | |
output = torch.bmm(adj, support) | |
if self.bias is not None: | |
return output + self.bias | |
else: | |
return output | |
class BatchFiGNN(nn.Module): | |
def __init__(self, f_in, f_out, out_channels): | |
super(BatchFiGNN, self).__init__() | |
# Edge Weights | |
self.a_src = Parameter(torch.Tensor(f_in, 1)) | |
self.a_dst = Parameter(torch.Tensor(f_in, 1)) | |
self.leaky_relu = nn.LeakyReLU(negative_slope=0.2) | |
self.softmax = nn.Softmax(dim=-1) | |
# Transformation | |
self.w = Parameter(torch.Tensor(f_in, f_out)) | |
self.bias = Parameter(torch.Tensor(f_out)) | |
# State Update by GRU | |
self.rnn = torch.nn.GRUCell(f_out, f_out, bias=True) | |
# Attention Pooling | |
self.mlp_1 = nn.Linear(f_out, out_channels, bias=True) | |
self.mlp_2 = nn.Linear(f_out, 1, bias=True) | |
init.xavier_uniform_(self.w) | |
init.constant_(self.bias, 0) | |
init.xavier_uniform_(self.a_src) | |
init.xavier_uniform_(self.a_dst) | |
def forward(self, h, adj, steps): | |
bs, n = h.size()[:2] | |
## Edge Weights | |
attn_src = torch.matmul(h, self.a_src) | |
attn_dst = torch.matmul(h, self.a_dst) | |
attn = attn_src.expand(-1, -1, n) + \ | |
attn_dst.expand(-1, -1, n).permute(0, 2, 1) | |
attn = self.leaky_relu(attn) | |
mask = torch.eye(adj.size()[-1]).unsqueeze(0).cuda() | |
mask = 1 - mask | |
attn = attn * mask | |
attn = self.softmax(attn) | |
## State Update | |
s = h | |
for _ in range(steps): | |
## Transformation | |
a = torch.matmul(s, self.w) | |
a = torch.matmul(attn, a) + self.bias | |
## GRU | |
s = self.rnn(s.view(-1, s.size()[-1]), a.view(-1, a.size()[-1])) | |
s = h.view(h.size()) + h | |
## Attention Pooling | |
output = self.mlp_1(s) | |
weight = self.mlp_2(s).permute(0, 2, 1) | |
output = torch.matmul(weight, output).squeeze() | |
return output | |
class BatchGAT(nn.Module): | |
def __init__(self, n_head, f_in, f_out, attn_dropout, bias=True): | |
super(BatchGAT, self).__init__() | |
self.n_head = n_head | |
self.w = Parameter(torch.Tensor(n_head, f_in, f_out)) | |
self.a_src = Parameter(torch.Tensor(n_head, f_out, 1)) | |
self.a_dst = Parameter(torch.Tensor(n_head, f_out, 1)) | |
self.leaky_relu = nn.LeakyReLU(negative_slope=0.2) | |
self.softmax = nn.Softmax(dim=-1) | |
self.dropout = nn.Dropout(attn_dropout) | |
if bias: | |
self.bias = Parameter(torch.Tensor(f_out)) | |
init.constant_(self.bias, 0) | |
else: | |
self.register_parameter('bias', None) | |
init.xavier_uniform_(self.w) | |
init.xavier_uniform_(self.a_src) | |
init.xavier_uniform_(self.a_dst) | |
def forward(self, h, adj): | |
bs, n = h.size()[:2] | |
h_prime = torch.matmul(h.unsqueeze(1), self.w) | |
attn_src = torch.matmul(torch.tanh(h_prime), self.a_src) | |
attn_dst = torch.matmul(torch.tanh(h_prime), self.a_dst) | |
attn = attn_src.expand(-1, -1, -1, n) + \ | |
attn_dst.expand(-1, -1, -1, n).permute(0, 1, 3, 2) | |
attn = self.leaky_relu(attn) | |
attn = self.softmax(attn) | |
attn = self.dropout(attn) | |
output = torch.matmul(attn, h_prime) | |
if self.bias is not None: | |
return output + self.bias | |
else: | |
return output | |