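# Batched graph layers (summary of what this module defines):
#   - BatchAGC:   graph convolution  X' = A (X W) + b over batched graphs
#   - BatchFiGNN: FiGNN-style interaction layer with edge attention, a GRU
#                 state update, and attention pooling
#   - BatchGAT:   multi-head graph attention
# All modules take node features of shape (bs, n, f) and adjacency (bs, n, n).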
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class BatchAGC(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(BatchAGC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)
        init.xavier_uniform_(self.weight)

    def forward(self, x, adj):
        # x: (bs, n, in_features), adj: (bs, n, n) -> output: (bs, n, out_features)
        expand_weight = self.weight.expand(x.shape[0], -1, -1)
        support = torch.bmm(x, expand_weight)   # X W
        output = torch.bmm(adj, support)        # A (X W)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

class BatchFiGNN(nn.Module):
    def __init__(self, f_in, f_out, out_channels):
        super(BatchFiGNN, self).__init__()
        # Edge Weights
        self.a_src = Parameter(torch.Tensor(f_in, 1))
        self.a_dst = Parameter(torch.Tensor(f_in, 1))
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = nn.Softmax(dim=-1)
        # Transformation
        self.w = Parameter(torch.Tensor(f_in, f_out))
        self.bias = Parameter(torch.Tensor(f_out))
        # State Update by GRU
        self.rnn = torch.nn.GRUCell(f_out, f_out, bias=True)
        # Attention Pooling
        self.mlp_1 = nn.Linear(f_out, out_channels, bias=True)
        self.mlp_2 = nn.Linear(f_out, 1, bias=True)

        init.xavier_uniform_(self.w)
        init.constant_(self.bias, 0)
        init.xavier_uniform_(self.a_src)
        init.xavier_uniform_(self.a_dst)

    def forward(self, h, adj, steps):
        # h: (bs, n, f_in), adj: (bs, n, n); the GRU-based update assumes f_in == f_out
        bs, n = h.size()[:2]
        ## Edge Weights
        attn_src = torch.matmul(h, self.a_src)           # (bs, n, 1)
        attn_dst = torch.matmul(h, self.a_dst)           # (bs, n, 1)
        attn = attn_src.expand(-1, -1, n) + \
            attn_dst.expand(-1, -1, n).permute(0, 2, 1)  # (bs, n, n)
        attn = self.leaky_relu(attn)
        # Zero out self-interactions on the fully connected feature graph
        mask = 1 - torch.eye(n, device=h.device).unsqueeze(0)
        attn = attn * mask
        attn = self.softmax(attn)
        ## State Update
        s = h
        for _ in range(steps):
            ## Transformation
            a = torch.matmul(s, self.w)
            a = torch.matmul(attn, a) + self.bias
            ## GRU: the aggregated message is the input, the previous state the hidden
            s = self.rnn(a.view(-1, a.size()[-1]), s.view(-1, s.size()[-1]))
            s = s.view(h.size()) + h   # residual connection back to the input state
        ## Attention Pooling
        output = self.mlp_1(s)                              # (bs, n, out_channels)
        weight = self.mlp_2(s).permute(0, 2, 1)             # (bs, 1, n)
        output = torch.matmul(weight, output).squeeze(1)    # (bs, out_channels)
        return output

class BatchGAT(nn.Module):
    def __init__(self, n_head, f_in, f_out, attn_dropout, bias=True):
        super(BatchGAT, self).__init__()
        self.n_head = n_head
        self.w = Parameter(torch.Tensor(n_head, f_in, f_out))
        self.a_src = Parameter(torch.Tensor(n_head, f_out, 1))
        self.a_dst = Parameter(torch.Tensor(n_head, f_out, 1))

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_dropout)
        if bias:
            self.bias = Parameter(torch.Tensor(f_out))
            init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)

        init.xavier_uniform_(self.w)
        init.xavier_uniform_(self.a_src)
        init.xavier_uniform_(self.a_dst)

    def forward(self, h, adj):
        # h: (bs, n, f_in), adj: (bs, n, n); output: (bs, n_head, n, f_out)
        bs, n = h.size()[:2]
        h_prime = torch.matmul(h.unsqueeze(1), self.w)             # (bs, n_head, n, f_out)
        attn_src = torch.matmul(torch.tanh(h_prime), self.a_src)   # (bs, n_head, n, 1)
        attn_dst = torch.matmul(torch.tanh(h_prime), self.a_dst)   # (bs, n_head, n, 1)
        attn = attn_src.expand(-1, -1, -1, n) + \
            attn_dst.expand(-1, -1, -1, n).permute(0, 1, 3, 2)     # (bs, n_head, n, n)
        attn = self.leaky_relu(attn)
        # Restrict attention to neighbors: non-edges in adj get -inf before softmax
        attn = attn.masked_fill(adj.unsqueeze(1) == 0, float('-inf'))
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.matmul(attn, h_prime)
        if self.bias is not None:
            return output + self.bias
        else:
            return output
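
# Minimal smoke test (an illustrative sketch, not part of the original module).
# The batch size, node count, and feature dims below are assumptions; note that
# BatchFiGNN's state update requires f_in == f_out.
if __name__ == "__main__":
    bs, n, d = 2, 5, 8
    x = torch.randn(bs, n, d)
    adj = torch.ones(bs, n, n)   # fully connected graph, for illustration only

    agc = BatchAGC(in_features=d, out_features=16)
    print(agc(x, adj).shape)             # torch.Size([2, 5, 16])

    fignn = BatchFiGNN(f_in=d, f_out=d, out_channels=4)
    print(fignn(x, adj, steps=2).shape)  # torch.Size([2, 4])

    gat = BatchGAT(n_head=3, f_in=d, f_out=16, attn_dropout=0.1)
    print(gat(x, adj).shape)             # torch.Size([2, 3, 5, 16])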