# Building blocks for message passing: an attention-weighted intra-message
# layer, an edge-feature-based inter-message layer, and a GRU node update.
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax  # in DGL >= 0.5 also exposed as dgl.nn.functional.edge_softmax

class intra_message(nn.Module):
    """Attention-weighted message passing within a graph."""

    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(intra_message, self).__init__()

        # Scores each edge from the concatenated (dst, src) node features.
        self.project_edge = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * node_feat_size, 1),
            nn.LeakyReLU()
        )
        # Projects node features to the graph feature size before aggregation.
        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, graph_feat_size),
            nn.LeakyReLU()
        )

        # Note: defined here but not used in forward().
        self.bn_layer = nn.BatchNorm1d(graph_feat_size)

    def apply_edges(self, edges):
        # Concatenate destination and source node features on every edge.
        return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}

    def forward(self, g, node_feats):
        g = g.local_var()
        g.ndata['hv'] = node_feats
        # Compute per-edge attention logits and normalize them per destination node.
        g.apply_edges(self.apply_edges)
        logits = self.project_edge(g.edata['he'])
        g.edata['a'] = edge_softmax(g, logits)
        # Weight the projected source features by attention and sum them into each node.
        g.ndata['hv'] = self.project_node(node_feats)
        g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))  # src_mul_edge is the legacy name of fn.u_mul_e

        return F.elu(g.ndata['c'])

class inter_message(nn.Module):
    """Mean-aggregated messages built from edge features and endpoint node features."""

    def __init__(self, in_dim, out_dim, dropout):
        super(inter_message, self).__init__()
        # in_dim must equal edge_feat_size + 2 * node_feat_size (see apply_edges).
        self.project_edges = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim, out_dim),
            nn.LeakyReLU()
        )

    def apply_edges(self, edges):
        # Expects edge features to be stored on the graph under g.edata['e'].
        return {'m': self.project_edges(torch.cat([edges.data['e'], edges.src['h'], edges.dst['h']], dim=1))}

    def forward(self, g, node_feats):
        g = g.local_var()
        g.ndata['h'] = node_feats
        # Build messages on the edges and average them per destination node.
        g.update_all(self.apply_edges, fn.mean('m', 'c'))
        return F.elu(g.ndata['c'])

class update_node_feats(nn.Module):
    """GRU-based node-feature update combining the old features with both message types."""

    def __init__(self, in_dim, out_dim, dropout):
        super(update_node_feats, self).__init__()
        self.gru = nn.GRUCell(out_dim, out_dim)
        # in_dim must equal node_feat_size + intra-message size + inter-message size;
        # out_dim must equal the node feature size used as the GRU hidden state.
        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim, out_dim),
            nn.LeakyReLU()
        )
        self.bn_layer = nn.BatchNorm1d(out_dim)

    def forward(self, g, node_feats, intra_m, inter_m):
        g = g.local_var()
        # Project the concatenated features, feed them to the GRU with the old
        # node features as the hidden state, then apply ReLU and batch norm.
        projected = self.project_node(torch.cat([node_feats, intra_m, inter_m], dim=1))
        return self.bn_layer(F.relu(self.gru(projected, node_feats)))
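

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a DGL version where dgl.graph(), the legacy fn.src_mul_edge and
# dgl.nn.pytorch.edge_softmax all coexist (roughly DGL 0.5.x/0.6.x); the graph,
# the feature sizes and the edge-feature name 'e' (read by
# inter_message.apply_edges) are invented here for demonstration.
if __name__ == "__main__":
    import dgl

    node_feat_size = 16    # must equal update_node_feats' out_dim (GRU hidden size)
    graph_feat_size = 16
    edge_feat_size = 8

    # Tiny 4-node cycle with random node and edge features.
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    g.edata['e'] = torch.randn(g.number_of_edges(), edge_feat_size)
    node_feats = torch.randn(g.number_of_nodes(), node_feat_size)

    intra = intra_message(node_feat_size, graph_feat_size, dropout=0.1)
    inter = inter_message(edge_feat_size + 2 * node_feat_size, graph_feat_size, dropout=0.1)
    update = update_node_feats(node_feat_size + 2 * graph_feat_size, node_feat_size, dropout=0.1)

    intra_m = intra(g, node_feats)                        # (num_nodes, graph_feat_size)
    inter_m = inter(g, node_feats)                        # (num_nodes, graph_feat_size)
    new_feats = update(g, node_feats, intra_m, inter_m)   # (num_nodes, node_feat_size)
    print(new_feats.shape)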