Spaces:
Running
Running
import dgl | |
import dgl.function as fn | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from dgl.nn.pytorch import edge_softmax | |
class AttentiveGRU1(nn.Module):
    """Update node features from an attention-weighted sum of edge features.

    Edge logits are normalised with ``edge_softmax`` (DGL's default
    normalises over the incoming edges of each destination node). The
    normalised weights multiply the projected edge features, and the
    results are summed per destination node into a context vector. The
    ELU of that context is the input of a single GRU step whose hidden
    state is the current node features.

    Args:
        node_feat_size: Width of the node features. It is also the GRU
            hidden size, so the returned features have this width.
        edge_feat_size: Width of the input edge features.
        edge_hidden_size: Width of the projected edge features, which is
            the GRU input size.
        dropout: Dropout probability applied before the edge projection.
    """

    def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU1, self).__init__()
        self.edge_transform = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(edge_feat_size, edge_hidden_size),
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, edge_feats, node_feats):
        """Run one attentive aggregation followed by a GRU update.

        Args:
            g: DGL graph.
            edge_logits: Unnormalised attention logits, one row per edge.
                Presumably of shape (E, 1), so that they broadcast over the
                projected edge features (TODO confirm).
            edge_feats: Edge features of width ``edge_feat_size``, one row
                per edge.
            node_feats: Current node features of width ``node_feat_size``
                (the GRU hidden state).

        Returns:
            Updated node features of width ``node_feat_size``, after ReLU.
        """
        # local_scope() discards the temporary 'e'/'c' fields on exit, so the
        # caller's graph is not modified. For example, a pre-existing edata['e']
        # holding bond features would otherwise be overwritten.
        with g.local_scope():
            g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
            # fn.copy_e replaces the deprecated fn.copy_edge, which newer DGL
            # releases no longer provide.
            g.update_all(fn.copy_e('e', 'm'), fn.sum('m', 'c'))
            context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))
class AttentiveGRU2(nn.Module):
    """Update node features from attention-weighted neighbour features.

    Node features are projected to ``edge_hidden_size``. Each edge sends
    its source node's projection scaled by the edge's softmax-normalised
    attention weight, and the messages are summed at the destination node.
    The ELU of that sum drives a single GRU step on the node features.

    Args:
        node_feat_size: Width of the node features. It is also the GRU
            hidden size, so the returned features have this width.
        edge_hidden_size: Width of the projected node features, which is
            the GRU input size.
        dropout: Dropout probability applied before the node projection.
    """

    def __init__(self, node_feat_size, edge_hidden_size, dropout):
        super(AttentiveGRU2, self).__init__()
        self.project_node = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(node_feat_size, edge_hidden_size),
        )
        self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)

    def forward(self, g, edge_logits, node_feats):
        """Run one attentive neighbour aggregation followed by a GRU update.

        Args:
            g: DGL graph.
            edge_logits: Unnormalised attention logits, one row per edge.
                Presumably of shape (E, 1), so that they broadcast over the
                projected node features (TODO confirm).
            node_feats: Current node features of width ``node_feat_size``.

        Returns:
            Updated node features of width ``node_feat_size``, after ReLU.
        """
        # local_scope() discards the temporary 'a'/'hv'/'c' fields on exit.
        with g.local_scope():
            g.edata['a'] = edge_softmax(g, edge_logits)
            g.ndata['hv'] = self.project_node(node_feats)
            # fn.u_mul_e replaces the deprecated fn.src_mul_edge, which newer
            # DGL releases no longer provide.
            g.update_all(fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'c'))
            context = F.elu(g.ndata['c'])
        return F.relu(self.gru(context, node_feats))
class GetContext(nn.Module):
    """Turn raw node and edge features into initial graph-aware node features.

    Pipeline, as implemented in ``forward``:

    1. Project node features to ``graph_feat_size`` (field ``hv_new``).
    2. For each edge, concatenate the source node's raw features with the
       edge features and project the result to ``graph_feat_size``
       (field ``he1``).
    3. For each edge, concatenate the destination node's ``hv_new`` with
       ``he1`` and map the result to a single attention logit.
    4. ``AttentiveGRU1`` aggregates ``he1`` with the softmaxed logits and
       updates ``hv_new``.

    Args:
        node_feat_size: Width of the raw node features.
        edge_feat_size: Width of the raw edge features.
        graph_feat_size: Width of every hidden representation, including
            the returned node features.
        dropout: Dropout probability used in the logit head and inside
            ``AttentiveGRU1``.
    """

    def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
        super().__init__()
        hidden = graph_feat_size
        # Submodules are registered in the same order as before, so parameter
        # initialisation order and state_dict keys are unchanged.
        self.project_node = nn.Sequential(
            nn.Linear(node_feat_size, hidden),
            nn.LeakyReLU(),
        )
        self.project_edge1 = nn.Sequential(
            nn.Linear(node_feat_size + edge_feat_size, hidden),
            nn.LeakyReLU(),
        )
        self.project_edge2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * hidden, 1),
            nn.LeakyReLU(),
        )
        self.attentive_gru = AttentiveGRU1(hidden, hidden, hidden, dropout)

    def apply_edges1(self, edges):
        """Edge UDF: join source-node raw features with edge features into 'he1'."""
        joined = torch.cat((edges.src['hv'], edges.data['he']), dim=1)
        return {'he1': joined}

    def apply_edges2(self, edges):
        """Edge UDF: join destination-node projections with 'he1' into 'he2'."""
        joined = torch.cat((edges.dst['hv_new'], edges.data['he1']), dim=1)
        return {'he2': joined}

    def forward(self, g, node_feats, edge_feats):
        """Compute initial node features of width ``graph_feat_size``.

        Args:
            g: DGL graph.
            node_feats: Raw node features of width ``node_feat_size``.
            edge_feats: Raw edge features of width ``edge_feat_size``.

        Returns:
            Node features of width ``graph_feat_size``, as returned by
            ``AttentiveGRU1``.
        """
        # Every temporary field written below is discarded when the scope exits.
        with g.local_scope():
            g.ndata['hv'] = node_feats
            g.ndata['hv_new'] = self.project_node(node_feats)
            g.edata['he'] = edge_feats

            g.apply_edges(self.apply_edges1)
            g.edata['he1'] = self.project_edge1(g.edata['he1'])

            g.apply_edges(self.apply_edges2)
            attn_logits = self.project_edge2(g.edata['he2'])

            edge_repr = g.edata['he1']
            node_repr = g.ndata['hv_new']
            return self.attentive_gru(g, attn_logits, edge_repr, node_repr)
class GNNLayer(nn.Module):
    """One attentive message-passing layer followed by batch normalisation.

    For each edge, the destination and source node features are
    concatenated (in that order) and mapped to a single attention logit.
    ``AttentiveGRU2`` uses these logits to aggregate neighbour features and
    update the nodes, and the result is batch-normalised.

    Args:
        node_feat_size: Width of the input node features. It is also the
            width of the output, because ``AttentiveGRU2`` uses it as the
            GRU hidden size.
        graph_feat_size: Width of the intermediate projection inside
            ``AttentiveGRU2``.
        dropout: Dropout probability for the logit head and inside
            ``AttentiveGRU2``.
    """

    def __init__(self, node_feat_size, graph_feat_size, dropout):
        super(GNNLayer, self).__init__()
        self.project_edge = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(2 * node_feat_size, 1),
            nn.LeakyReLU(),
        )
        self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)
        # AttentiveGRU2's GRUCell has hidden size node_feat_size, so that is the
        # width of the tensor being normalised. The previous value,
        # graph_feat_size, only worked when the two sizes happened to be equal.
        # In that case the parameter shapes, and therefore checkpoints, are
        # unchanged.
        self.bn_layer = nn.BatchNorm1d(node_feat_size)

    def apply_edges(self, edges):
        """Edge UDF: concatenate destination and source features into 'he'."""
        return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}

    def forward(self, g, node_feats):
        """Update node features with one attentive layer.

        Args:
            g: DGL graph.
            node_feats: Node features of width ``node_feat_size``.

        Returns:
            Batch-normalised node features of width ``node_feat_size``.
        """
        # local_scope() discards the temporary 'hv'/'he' fields on exit.
        with g.local_scope():
            g.ndata['hv'] = node_feats
            g.apply_edges(self.apply_edges)
            logits = self.project_edge(g.edata['he'])
            updated = self.attentive_gru(g, logits, node_feats)
        return self.bn_layer(updated)
class ModifiedAttentiveFPGNNV2(nn.Module):
    """AttentiveFP-style node encoder with a jumping-knowledge readout.

    ``GetContext`` produces the initial node features. ``num_layers - 1``
    ``GNNLayer`` blocks then refine them. The outputs of all stages are
    combined according to ``jk``:

    * ``'sum'``: element-wise sum over stages.
    * ``'max'``: element-wise maximum over stages.
    * ``'concat'``: concatenation along the feature dimension, giving width
      ``graph_feat_size * number_of_stages``.
    * ``'last'``: the output of the final stage only.

    Args:
        node_feat_size: Width of the raw node features in ``g.ndata['h']``.
        edge_feat_size: Width of the raw edge features in ``g.edata['e']``.
        num_layers: Total number of stages, including the initial context.
        graph_feat_size: Hidden width used by every stage.
        dropout: Dropout probability passed to all submodules.
        jk: Readout mode; one of ``'sum'``, ``'max'``, ``'concat'``, ``'last'``.
    """

    _JK_MODES = ('sum', 'max', 'concat', 'last')

    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers=2,
                 graph_feat_size=200,
                 dropout=0.,
                 jk='sum'):
        super(ModifiedAttentiveFPGNNV2, self).__init__()
        self.jk = jk
        self.graph_feat_size = graph_feat_size
        self.num_layers = num_layers
        self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
        self.gnn_layers = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))

    def forward(self, g, Perturb=None):
        """Encode the nodes of ``g``.

        Args:
            g: DGL graph with node features in ``ndata['h']`` (cast to float)
                and edge features in ``edata['e']``.
            Perturb: Optional tensor added to the initial node features
                before the GNN layers. Presumably an adversarial or noise
                perturbation used during training (confirm with callers).

        Returns:
            Node representations combined according to ``self.jk``.

        Raises:
            ValueError: If ``self.jk`` is not a supported readout mode.
        """
        # An unknown mode used to fall through every branch and make forward
        # silently return None. Fail loudly, before doing any work.
        if self.jk not in self._JK_MODES:
            raise ValueError(
                f"jk must be one of {self._JK_MODES}, got {self.jk!r}"
            )

        atom_feats = g.ndata['h'].float()
        bond_feats = g.edata['e']
        node_feats = self.init_context(g, atom_feats, bond_feats)
        if Perturb is not None:
            node_feats = node_feats + Perturb

        h_list = [node_feats]
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats)
            h_list.append(node_feats)

        if self.jk == 'sum':
            return torch.stack(h_list, dim=0).sum(dim=0)
        if self.jk == 'max':
            return torch.stack(h_list, dim=0).max(dim=0).values
        if self.jk == 'concat':
            return torch.cat(h_list, dim=1)
        return h_list[-1]  # 'last'
class DTIConvGraph3(nn.Module):
    """Edge-update block that refreshes every edge from itself and its endpoints.

    For each edge, the current edge feature ``'e'`` is concatenated with the
    source and destination node features ``'h'``. The result goes through a
    three-layer MLP in which every Linear layer is followed by LeakyReLU.
    ``in_dim`` must therefore equal the edge feature width plus twice the
    node feature width.

    Args:
        in_dim: Width of the concatenated (edge, src, dst) input.
        out_dim: Width of the updated edge features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Build [Linear, LeakyReLU] x 3 in one Sequential named ``mpl``. The
        # layer order, and therefore the checkpoint keys mpl.0/.2/.4, are fixed.
        layers = [nn.Linear(in_dim, out_dim), nn.LeakyReLU()]
        for _ in range(2):
            layers += [nn.Linear(out_dim, out_dim), nn.LeakyReLU()]
        self.mpl = nn.Sequential(*layers)

    def EdgeUpdate(self, edges):
        """Edge UDF: return the MLP-updated edge features under key 'e'."""
        features = torch.cat(
            [edges.data['e'], edges.src['h'], edges.dst['h']], dim=1
        )
        return {'e': self.mpl(features)}

    def forward(self, bg):
        """Return updated edge features without altering ``bg``'s stored data."""
        with bg.local_scope():
            bg.apply_edges(self.EdgeUpdate)
            new_edge_feats = bg.edata['e']
        return new_edge_feats
class DTIConvGraph3Layer(nn.Module):
    """Wrap ``DTIConvGraph3`` with dropout and batch normalisation.

    The wrapped edge update concatenates the edge feature with the source
    and destination node features, so ``in_dim`` must equal the edge width
    plus twice the node width.

    Args:
        in_dim: Input width forwarded to ``DTIConvGraph3``.
        out_dim: Width of the produced edge features.
        dropout: Dropout probability applied to the updated edge features.
    """

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        # NOTE: the attribute name 'grah_conv' (sic) is kept because it
        # appears in state_dict keys of existing checkpoints.
        self.grah_conv = DTIConvGraph3(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn_layer = nn.BatchNorm1d(out_dim)

    def forward(self, bg):
        """Return dropout- and batch-normalised updated edge features of ``bg``."""
        edge_feats = self.grah_conv(bg)
        edge_feats = self.dropout(edge_feats)
        return self.bn_layer(edge_feats)
class DTIConvGraph3_IGN_basic(nn.Module):
    """Edge-update block using the sum of the endpoint features.

    For each edge, the current edge feature ``'e'`` is concatenated with the
    element-wise sum of the source and destination node features ``'h'``.
    The result goes through a three-layer MLP in which every Linear layer is
    followed by LeakyReLU. ``in_dim`` must therefore equal the edge feature
    width plus the node feature width.

    Args:
        in_dim: Width of the concatenated (edge, src + dst) input.
        out_dim: Width of the updated edge features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Build [Linear, LeakyReLU] x 3 in one Sequential named ``mpl``. The
        # layer order, and therefore the checkpoint keys mpl.0/.2/.4, are fixed.
        layers = [nn.Linear(in_dim, out_dim), nn.LeakyReLU()]
        for _ in range(2):
            layers += [nn.Linear(out_dim, out_dim), nn.LeakyReLU()]
        self.mpl = nn.Sequential(*layers)

    def EdgeUpdate(self, edges):
        """Edge UDF: return the MLP-updated edge features under key 'e'."""
        endpoint_sum = edges.src['h'] + edges.dst['h']
        features = torch.cat([edges.data['e'], endpoint_sum], dim=1)
        return {'e': self.mpl(features)}

    def forward(self, bg):
        """Return updated edge features without altering ``bg``'s stored data."""
        with bg.local_scope():
            bg.apply_edges(self.EdgeUpdate)
            new_edge_feats = bg.edata['e']
        return new_edge_feats
class DTIConvGraph3Layer_IGN_basic(nn.Module):
    """Wrap ``DTIConvGraph3_IGN_basic`` with dropout and batch normalisation.

    The wrapped edge update concatenates the edge feature with the sum of
    the endpoint node features, so ``in_dim`` must equal the edge width plus
    the node width.

    Args:
        in_dim: Input width forwarded to ``DTIConvGraph3_IGN_basic``.
        out_dim: Width of the produced edge features.
        dropout: Dropout probability applied to the updated edge features.
    """

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        # NOTE: the attribute name 'grah_conv' (sic) is kept because it
        # appears in state_dict keys of existing checkpoints.
        self.grah_conv = DTIConvGraph3_IGN_basic(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn_layer = nn.BatchNorm1d(out_dim)

    def forward(self, bg):
        """Return dropout- and batch-normalised updated edge features of ``bg``."""
        edge_feats = self.grah_conv(bg)
        edge_feats = self.dropout(edge_feats)
        return self.bn_layer(edge_feats)