import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax


class intra_message(nn.Module):
    """Attention-weighted aggregation of neighbour features over a graph.

    Each edge gets a scalar score computed from the concatenated features of
    its destination and source nodes.  The scores are normalised with
    ``edge_softmax`` and used to weight the projected source-node features,
    which are then summed at every destination node.

    Args:
        node_feat_size: Width of the incoming node features.
        graph_feat_size: Width of the projected (output) node features.
        dropout: Dropout probability applied before each linear projection.
    """

    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(intra_message, self).__init__()
        # Scores one edge from the [dst, src] feature concatenation.
        self.project_edge = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * node_feat_size, 1),
            nn.LeakyReLU(),
        )
        # Projects node features to the output width before aggregation.
        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, graph_feat_size),
            nn.LeakyReLU(),
        )
        # NOTE: forward() never calls this layer.  It is kept so that the
        # module's parameters and state_dict keys stay unchanged.
        self.bn_layer = nn.BatchNorm1d(graph_feat_size)

    def apply_edges(self, edges):
        """Build the per-edge input for the attention score.

        Returns:
            A dict with key ``'he'`` holding the destination and source
            ``'hv'`` features concatenated along dim 1 (destination first).
        """
        return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}

    def forward(self, g, node_feats):
        """Aggregate attention-weighted neighbour features.

        Args:
            g: DGL graph to pass messages on.  Only a local view is modified.
            node_feats: 2-D node feature tensor of width ``node_feat_size``.

        Returns:
            ELU-activated aggregated features of width ``graph_feat_size``,
            one row per node.
        """
        # Work on a local view so the caller's graph data is left untouched.
        g = g.local_var()
        g.ndata['hv'] = node_feats
        g.apply_edges(self.apply_edges)
        logits = self.project_edge(g.edata['he'])
        g.edata['a'] = edge_softmax(g, logits)
        # Overwrite 'hv' with the projected features that are actually sent.
        g.ndata['hv'] = self.project_node(node_feats)
        # u_mul_e is the current name of the deprecated src_mul_edge built-in:
        # message = source feature * edge weight.
        g.update_all(fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'c'))
        return F.elu(g.ndata['c'])


class inter_message(nn.Module):
    """Mean aggregation of projected edge messages.

    Every edge message is a projection of the concatenated edge feature,
    source node feature and destination node feature; each node receives the
    mean of its incoming messages.

    Args:
        in_dim: Width of the concatenated ``[e, h_src, h_dst]`` vector.
        out_dim: Width of each message and of the returned node features.
        dropout: Dropout probability applied before the linear projection.
    """

    def __init__(self, in_dim, out_dim, dropout):
        super(inter_message, self).__init__()
        self.project_edges = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim, out_dim),
            nn.LeakyReLU(),
        )

    def apply_edges(self, edges):
        """Message function passed to ``update_all``.

        Reads the edge feature ``'e'`` and the endpoint features ``'h'``.

        Returns:
            A dict with key ``'m'`` holding the projected message per edge.
        """
        edge_inputs = torch.cat(
            [edges.data['e'], edges.src['h'], edges.dst['h']], dim=1
        )
        return {'m': self.project_edges(edge_inputs)}

    def forward(self, g, node_feats):
        """Compute the mean incoming message for every node.

        Args:
            g: DGL graph whose edges carry a feature named ``'e'``.  Only a
                local view is modified.
            node_feats: 2-D node feature tensor stored as ``'h'``.

        Returns:
            ELU-activated mean messages of width ``out_dim``, one row per
            node.
        """
        # Work on a local view so the caller's graph data is left untouched.
        g = g.local_var()
        g.ndata['h'] = node_feats
        g.update_all(self.apply_edges, fn.mean('m', 'c'))
        return F.elu(g.ndata['c'])


class update_node_feats(nn.Module):
    """GRU-based node update that fuses node features with two messages.

    Args:
        in_dim: Width of the concatenated ``[node_feats, intra_m, inter_m]``
            vector.
        out_dim: Width of the GRU input and hidden state.  ``node_feats`` is
            used as the hidden state, so it must have this width.
        dropout: Dropout probability applied before the linear projection.
    """

    def __init__(self, in_dim, out_dim, dropout):
        super(update_node_feats, self).__init__()
        self.gru = nn.GRUCell(out_dim, out_dim)
        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim, out_dim),
            nn.LeakyReLU(),
        )
        self.bn_layer = nn.BatchNorm1d(out_dim)

    def forward(self, g, node_feats, intra_m, inter_m):
        """Return updated node features.

        Args:
            g: Unused; kept so existing callers keep working.
            node_feats: Current node features, used both as part of the GRU
                input and as the GRU hidden state.
            intra_m: First message tensor, concatenated along dim 1.
            inter_m: Second message tensor, concatenated along dim 1.

        Returns:
            Batch-normalised, ReLU-activated GRU output of width ``out_dim``.
        """
        combined = torch.cat([node_feats, intra_m, inter_m], dim=1)
        gru_input = self.project_node(combined)
        # node_feats acts as the previous hidden state of the GRU cell.
        hidden = self.gru(gru_input, node_feats)
        return self.bn_layer(F.relu(hidden))