# UltraFlow/layers/attentivefp.py

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import edge_softmax
class AttentiveGRU1(nn.Module):
def __init__(self, node_feat_size, edge_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU1, self).__init__()
self.edge_transform = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(edge_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, edge_feats, node_feats):
g = g.local_var()
g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
        # sum attention-weighted edge messages into a per-node context
        g.update_all(fn.copy_e('e', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
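
# Minimal usage sketch for AttentiveGRU1; the toy 3-node cycle and all feature
# sizes below are illustrative assumptions, not values fixed by this module.
def _demo_attentive_gru1():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))            # toy 3-node cycle
    edge_logits = torch.randn(g.num_edges(), 1)      # unnormalized attention scores
    edge_feats = torch.randn(g.num_edges(), 8)       # edge_feat_size = 8 (assumed)
    node_feats = torch.randn(g.num_nodes(), 16)      # node_feat_size = 16 (assumed)
    layer = AttentiveGRU1(node_feat_size=16, edge_feat_size=8,
                          edge_hidden_size=8, dropout=0.1)
    return layer(g, edge_logits, edge_feats, node_feats)  # shape (3, 16)
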
class AttentiveGRU2(nn.Module):
def __init__(self, node_feat_size, edge_hidden_size, dropout):
super(AttentiveGRU2, self).__init__()
self.project_node = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(node_feat_size, edge_hidden_size)
)
self.gru = nn.GRUCell(edge_hidden_size, node_feat_size)
def forward(self, g, edge_logits, node_feats):
g = g.local_var()
g.edata['a'] = edge_softmax(g, edge_logits)
g.ndata['hv'] = self.project_node(node_feats)
        # weight projected source-node features by edge attention, then sum per node
        g.update_all(fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))
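
# Minimal usage sketch for AttentiveGRU2; unlike AttentiveGRU1, the attended
# messages come from projected source-node features rather than edge features.
# Sizes are illustrative assumptions.
def _demo_attentive_gru2():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    edge_logits = torch.randn(g.num_edges(), 1)
    node_feats = torch.randn(g.num_nodes(), 16)
    layer = AttentiveGRU2(node_feat_size=16, edge_hidden_size=8, dropout=0.1)
    return layer(g, edge_logits, node_feats)  # shape (3, 16)
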
class GetContext(nn.Module):
def __init__(self, node_feat_size, edge_feat_size, graph_feat_size, dropout):
super(GetContext, self).__init__()
self.project_node = nn.Sequential(
nn.Linear(node_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge1 = nn.Sequential(
nn.Linear(node_feat_size + edge_feat_size, graph_feat_size),
nn.LeakyReLU()
)
self.project_edge2 = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * graph_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU1(graph_feat_size, graph_feat_size,
graph_feat_size, dropout)
def apply_edges1(self, edges):
return {'he1': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}
def apply_edges2(self, edges):
return {'he2': torch.cat([edges.dst['hv_new'], edges.data['he1']], dim=1)}
def forward(self, g, node_feats, edge_feats):
g = g.local_var()
g.ndata['hv'] = node_feats
g.ndata['hv_new'] = self.project_node(node_feats)
g.edata['he'] = edge_feats
g.apply_edges(self.apply_edges1)
g.edata['he1'] = self.project_edge1(g.edata['he1'])
g.apply_edges(self.apply_edges2)
logits = self.project_edge2(g.edata['he2'])
return self.attentive_gru(g, logits, g.edata['he1'], g.ndata['hv_new'])
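
# Minimal usage sketch for GetContext, which fuses raw node and edge features
# into an initial graph_feat_size-dim node state. The raw feature sizes below
# (39 atom / 11 bond) are illustrative assumptions.
def _demo_get_context():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    node_feats = torch.randn(g.num_nodes(), 39)
    edge_feats = torch.randn(g.num_edges(), 11)
    ctx = GetContext(node_feat_size=39, edge_feat_size=11,
                     graph_feat_size=32, dropout=0.1)
    return ctx(g, node_feats, edge_feats)  # shape (4, 32)
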
class GNNLayer(nn.Module):
def __init__(self, node_feat_size, graph_feat_size, dropout):
super(GNNLayer, self).__init__()
self.project_edge = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(2 * node_feat_size, 1),
nn.LeakyReLU()
)
self.attentive_gru = AttentiveGRU2(node_feat_size, graph_feat_size, dropout)
self.bn_layer = nn.BatchNorm1d(graph_feat_size)
def apply_edges(self, edges):
return {'he': torch.cat([edges.dst['hv'], edges.src['hv']], dim=1)}
def forward(self, g, node_feats):
g = g.local_var()
g.ndata['hv'] = node_feats
g.apply_edges(self.apply_edges)
logits = self.project_edge(g.edata['he'])
return self.bn_layer(self.attentive_gru(g, logits, node_feats))
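
# Minimal usage sketch for GNNLayer; note that BatchNorm1d needs more than one
# node in training mode. Sizes are illustrative assumptions.
def _demo_gnn_layer():
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    node_feats = torch.randn(g.num_nodes(), 32)
    layer = GNNLayer(node_feat_size=32, graph_feat_size=32, dropout=0.1)
    return layer(g, node_feats)  # shape (4, 32)
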
class ModifiedAttentiveFPGNNV2(nn.Module):
def __init__(self,
node_feat_size,
edge_feat_size,
num_layers=2,
graph_feat_size=200,
dropout=0.,
jk='sum'):
super(ModifiedAttentiveFPGNNV2, self).__init__()
self.jk = jk
self.graph_feat_size = graph_feat_size
self.num_layers = num_layers
self.init_context = GetContext(node_feat_size, edge_feat_size, graph_feat_size, dropout)
self.gnn_layers = nn.ModuleList()
for _ in range(num_layers - 1):
self.gnn_layers.append(GNNLayer(graph_feat_size, graph_feat_size, dropout))
def forward(self, g, Perturb=None):
atom_feats = g.ndata['h'].float()
bond_feats = g.edata['e']
node_feats = self.init_context(g, atom_feats, bond_feats)
if Perturb is not None:
node_feats = node_feats + Perturb
h_list = [node_feats]
for gnn in self.gnn_layers:
node_feats = gnn(g, node_feats)
h_list.append(node_feats)
        # jumping-knowledge readout over the per-layer node representations
        if self.jk == 'sum':
            h_list = [h.unsqueeze(0) for h in h_list]
            return torch.sum(torch.cat(h_list, dim=0), dim=0)
        elif self.jk == 'max':
            h_list = [h.unsqueeze(0) for h in h_list]
            return torch.max(torch.cat(h_list, dim=0), dim=0)[0]
        elif self.jk == 'concat':
            return torch.cat(h_list, dim=1)
        elif self.jk == 'last':
            return h_list[-1]
        else:
            raise ValueError('unknown jk mode: {}'.format(self.jk))
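
# End-to-end smoke test for ModifiedAttentiveFPGNNV2 on a random toy graph.
# The input feature sizes (39 atom / 11 bond) and graph_feat_size=64 are
# illustrative assumptions, not values fixed by this module.
def _demo_modified_attentive_fp():
    src = torch.tensor([0, 1, 2, 3, 4])
    dst = torch.tensor([1, 2, 3, 4, 0])
    g = dgl.graph((src, dst))
    g.ndata['h'] = torch.randn(g.num_nodes(), 39)
    g.edata['e'] = torch.randn(g.num_edges(), 11)
    model = ModifiedAttentiveFPGNNV2(node_feat_size=39, edge_feat_size=11,
                                     num_layers=2, graph_feat_size=64, jk='sum')
    return model(g)  # shape (5, 64) with jk='sum'
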
class DTIConvGraph3(nn.Module):
def __init__(self, in_dim, out_dim):
super(DTIConvGraph3, self).__init__()
        # the MLP that updates the edge state
        self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim),
nn.LeakyReLU(),
nn.Linear(out_dim, out_dim),
nn.LeakyReLU(),
nn.Linear(out_dim, out_dim),
nn.LeakyReLU())
def EdgeUpdate(self, edges):
        return {'e': self.mlp(torch.cat([edges.data['e'], edges.src['h'], edges.dst['h']], dim=1))}
def forward(self, bg):
with bg.local_scope():
bg.apply_edges(self.EdgeUpdate)
return bg.edata['e']
class DTIConvGraph3Layer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout):  # in_dim = 2 * (graph module1 output dim) + 1: src, dst and the 1-dim edge feature are concatenated
super(DTIConvGraph3Layer, self).__init__()
        # the graph convolution that updates the edge state
        self.graph_conv = DTIConvGraph3(in_dim, out_dim)
self.dropout = nn.Dropout(dropout)
self.bn_layer = nn.BatchNorm1d(out_dim)
def forward(self, bg):
        new_feats = self.graph_conv(bg)
return self.bn_layer(self.dropout(new_feats))
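
# Minimal usage sketch for DTIConvGraph3Layer. Its edge update concatenates the
# edge feature with the source and destination node features, so in_dim must be
# edge dim + 2 * node dim; the 1-dim edge feature (e.g. a distance) and 64-dim
# node states are assumptions.
def _demo_dti_conv_graph3():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g.ndata['h'] = torch.randn(g.num_nodes(), 64)
    g.edata['e'] = torch.randn(g.num_edges(), 1)
    layer = DTIConvGraph3Layer(in_dim=1 + 2 * 64, out_dim=32, dropout=0.1)
    return layer(g)  # shape (3, 32)
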
class DTIConvGraph3_IGN_basic(nn.Module):
def __init__(self, in_dim, out_dim):
super(DTIConvGraph3_IGN_basic, self).__init__()
        # the MLP that updates the edge state
        self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim),
nn.LeakyReLU(),
nn.Linear(out_dim, out_dim),
nn.LeakyReLU(),
nn.Linear(out_dim, out_dim),
nn.LeakyReLU())
def EdgeUpdate(self, edges):
        return {'e': self.mlp(torch.cat([edges.data['e'], edges.src['h'] + edges.dst['h']], dim=1))}
def forward(self, bg):
with bg.local_scope():
bg.apply_edges(self.EdgeUpdate)
return bg.edata['e']
class DTIConvGraph3Layer_IGN_basic(nn.Module):
    def __init__(self, in_dim, out_dim, dropout):  # in_dim = graph module1 output dim + 1: src and dst features are summed before concatenation with the 1-dim edge feature
super(DTIConvGraph3Layer_IGN_basic, self).__init__()
        # the graph convolution that updates the edge state
        self.graph_conv = DTIConvGraph3_IGN_basic(in_dim, out_dim)
self.dropout = nn.Dropout(dropout)
self.bn_layer = nn.BatchNorm1d(out_dim)
def forward(self, bg):
        new_feats = self.graph_conv(bg)
return self.bn_layer(self.dropout(new_feats))
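
# Minimal usage sketch for the IGN-basic variant, whose edge update sums the
# source and destination node features before concatenation, so in_dim = node
# dim + edge dim. Sizes are illustrative assumptions.
def _demo_dti_conv_graph3_ign_basic():
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g.ndata['h'] = torch.randn(g.num_nodes(), 64)
    g.edata['e'] = torch.randn(g.num_edges(), 1)
    layer = DTIConvGraph3Layer_IGN_basic(in_dim=64 + 1, out_dim=32, dropout=0.1)
    return layer(g)  # shape (3, 32)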