FairUP / src /models /CatGCN /gnn_layers.py
erasmopurif's picture
First commit
d2a8669
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class BatchAGC(nn.Module):
    """Graph convolution applied independently to every graph in a batch.

    Computes ``adj @ (x @ weight)`` per batch element, plus an optional bias.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        bias: if True, a learnable bias of size ``out_features`` (initialised
            to zero) is added to the output; otherwise ``self.bias`` is None.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(BatchAGC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_features))
            init.constant_(self.bias, 0)
        init.xavier_uniform_(self.weight)

    def forward(self, x, adj):
        """Propagate node features through the (batched) adjacency matrices.

        Args:
            x: 3-D node features, (bs, n, in_features) — required by torch.bmm.
            adj: 3-D batch of matrices multiplied on the left of the projected
                features; presumably (bs, n, n) — confirm against callers.

        Returns:
            Tensor of shape (bs, adj.size(1), out_features).
        """
        # bmm needs a 3-D weight, so repeat the shared weight (as a view)
        # once per batch element.
        batch_weight = self.weight.expand(x.shape[0], -1, -1)
        aggregated = torch.bmm(adj, torch.bmm(x, batch_weight))
        if self.bias is None:
            return aggregated
        return aggregated + self.bias
class BatchFiGNN(nn.Module):
    """Field-interaction GNN layer with GRU state updates and attention pooling.

    For ``steps`` rounds, every node aggregates attention-weighted, linearly
    transformed states of the other nodes. The aggregate and the current state
    go through a ``GRUCell``, and the GRU output gets a residual connection
    back to the input ``h``. The final node states are then pooled into one
    vector per sample.

    Args:
        f_in: input node feature size. Must equal ``f_out``, because the node
            state (initialised to ``h``) is fed to a ``GRUCell`` whose input
            size is ``f_out``, and the GRU output is reshaped to ``h``'s shape.
        f_out: size of the transformed messages / GRU hidden state.
        out_channels: size of the pooled output vector per sample.
    """

    def __init__(self, f_in, f_out, out_channels):
        super(BatchFiGNN, self).__init__()
        # Edge weights: per-node source / destination attention projections.
        self.a_src = Parameter(torch.Tensor(f_in, 1))
        self.a_dst = Parameter(torch.Tensor(f_in, 1))
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = nn.Softmax(dim=-1)
        # Transformation applied to node states before aggregation.
        self.w = Parameter(torch.Tensor(f_in, f_out))
        self.bias = Parameter(torch.Tensor(f_out))
        # State update by GRU.
        self.rnn = torch.nn.GRUCell(f_out, f_out, bias=True)
        # Attention pooling: mlp_1 projects node states, mlp_2 scores each node.
        self.mlp_1 = nn.Linear(f_out, out_channels, bias=True)
        self.mlp_2 = nn.Linear(f_out, 1, bias=True)
        init.xavier_uniform_(self.w)
        init.constant_(self.bias, 0)
        init.xavier_uniform_(self.a_src)
        init.xavier_uniform_(self.a_dst)

    def forward(self, h, adj, steps):
        """Run ``steps`` rounds of message passing, then pool the node states.

        Args:
            h: 3-D node features, (bs, n, f_in).
            adj: only ``adj.size(-1)`` is read, to size the diagonal mask
                (expected to equal n). Its values are not used.
            steps: number of message-passing / GRU update rounds. With 0
                rounds, ``h`` is pooled directly.

        Returns:
            ``torch.matmul(weight, mlp_1(s)).squeeze()``, i.e. normally
            (bs, out_channels).
            NOTE(review): ``.squeeze()`` also drops the batch dim when bs == 1
            and the channel dim when out_channels == 1.
        """
        n = h.size(1)
        # Edge weights: attn[b, i, j] = src[b, i] + dst[b, j].
        attn_src = torch.matmul(h, self.a_src)  # (bs, n, 1)
        attn_dst = torch.matmul(h, self.a_dst)  # (bs, n, 1)
        attn = attn_src.expand(-1, -1, n) + \
            attn_dst.expand(-1, -1, n).permute(0, 2, 1)
        attn = self.leaky_relu(attn)
        # Build the diagonal mask on the same device/dtype as the scores
        # instead of hard-coding .cuda(), so the layer also runs on CPU.
        eye = torch.eye(adj.size(-1), device=attn.device, dtype=attn.dtype)
        mask = 1 - eye.unsqueeze(0)
        # NOTE(review): zeroing scores before softmax does NOT remove the
        # diagonal; it still receives weight exp(0). A true exclusion would
        # use masked_fill(..., -inf). Kept as-is to preserve model behaviour.
        attn = attn * mask
        attn = self.softmax(attn)

        ## State Update
        s = h
        for _ in range(steps):
            ## Transformation + aggregation over the other nodes.
            a = torch.matmul(s, self.w)
            a = torch.matmul(attn, a) + self.bias
            ## GRU over flattened nodes.
            # NOTE(review): `s` is passed as the GRUCell input and `a` as the
            # hidden state; confirm this order is the intended formulation.
            s = self.rnn(s.reshape(-1, s.size(-1)), a.reshape(-1, a.size(-1)))
            # Residual connection on the GRU output, so the update is
            # actually propagated to the next round and to the pooling.
            s = s.reshape(h.size()) + h

        ## Attention Pooling
        output = self.mlp_1(s)  # (bs, n, out_channels)
        # Unnormalised per-node scores (no softmax is applied here).
        weight = self.mlp_2(s).permute(0, 2, 1)  # (bs, 1, n)
        output = torch.matmul(weight, output).squeeze()
        return output
class BatchGAT(nn.Module):
    """Multi-head graph attention over a batch of node sets.

    Every head scores all node pairs from tanh-squashed projections, so
    attention is taken over all nodes. The ``adj`` argument is not used for
    masking.

    Args:
        n_head: number of attention heads.
        f_in: input node feature size.
        f_out: per-head output feature size.
        attn_dropout: dropout probability applied to the attention weights.
        bias: if True, a learnable bias of size ``f_out`` (initialised to
            zero) is added to the output; otherwise ``self.bias`` is None.
    """

    def __init__(self, n_head, f_in, f_out, attn_dropout, bias=True):
        super(BatchGAT, self).__init__()
        self.n_head = n_head
        self.w = Parameter(torch.Tensor(n_head, f_in, f_out))
        self.a_src = Parameter(torch.Tensor(n_head, f_out, 1))
        self.a_dst = Parameter(torch.Tensor(n_head, f_out, 1))
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_dropout)
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(f_out))
            init.constant_(self.bias, 0)
        self._init_attention_weights()

    def _init_attention_weights(self):
        # Xavier-init projection and attention vectors, in this fixed order
        # (keeps seeded initialisation reproducible).
        for param in (self.w, self.a_src, self.a_dst):
            init.xavier_uniform_(param)

    def forward(self, h, adj):
        """Apply multi-head attention to the node features.

        Args:
            h: 3-D node features, (bs, n, f_in).
            adj: accepted for interface compatibility; not read.

        Returns:
            Tensor of shape (bs, n_head, n, f_out).
        """
        n = h.size(1)
        projected = torch.matmul(h.unsqueeze(1), self.w)  # (bs, n_head, n, f_out)
        gated = torch.tanh(projected)
        score_src = torch.matmul(gated, self.a_src)  # (bs, n_head, n, 1)
        score_dst = torch.matmul(gated, self.a_dst)  # (bs, n_head, n, 1)
        # Pairwise logits: logits[..., i, j] = score_src[i] + score_dst[j].
        logits = score_src.expand(-1, -1, -1, n) + \
            score_dst.expand(-1, -1, -1, n).transpose(2, 3)
        weights = self.dropout(self.softmax(self.leaky_relu(logits)))
        result = torch.matmul(weights, projected)
        return result if self.bias is None else result + self.bias