jiaxianustc's picture
test
3ad8be1
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class EGNNConv(nn.Module):
def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
super(EGNNConv, self).__init__()
self.in_size = in_size
self.hidden_size = hidden_size
self.out_size = out_size
self.edge_feat_size = edge_feat_size
act_fn = nn.SiLU()
# \phi_e
self.edge_mlp = nn.Sequential(
# +1 for the radial feature: ||x_i - x_j||^2
nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
act_fn,
nn.Linear(hidden_size, hidden_size),
act_fn
)
# \phi_h
self.node_mlp = nn.Sequential(
nn.Linear(in_size + hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, out_size)
)
# \phi_x
self.coord_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, 1, bias=False)
)
def message(self, edges):
"""message function for EGNN"""
# concat features for edge mlp
if self.edge_feat_size > 0:
f = torch.cat(
[edges.src['h'], edges.dst['h'], edges.data['radial'], edges.data['a']],
dim=-1
)
else:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['radial']], dim=-1)
msg_h = self.edge_mlp(f)
msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']
return {'msg_x': msg_x, 'msg_h': msg_h}
def forward(self, graph, node_feat, coord_feat, edge_feat=None):
with graph.local_scope():
# node feature
graph.ndata['h'] = node_feat
# coordinate feature
graph.ndata['x'] = coord_feat
# edge feature
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
# get coordinate diff & radial features
graph.apply_edges(fn.u_sub_v('x', 'x', 'x_diff'))
graph.edata['radial'] = graph.edata['x_diff'].square().sum(dim=1).unsqueeze(-1)
# normalize coordinate difference
graph.edata['x_diff'] = graph.edata['x_diff'] / (graph.edata['radial'].sqrt() + 1e-30)
graph.apply_edges(self.message)
graph.update_all(fn.copy_e('msg_x', 'm'), fn.mean('m', 'x_neigh'))
graph.update_all(fn.copy_e('msg_h', 'm'), fn.sum('m', 'h_neigh'))
h_neigh, x_neigh = graph.ndata['h_neigh'], graph.ndata['x_neigh']
h = self.node_mlp(
torch.cat([node_feat, h_neigh], dim=-1)
)
x = coord_feat + x_neigh
return h, x
class EGNN(nn.Module):
def __init__(self, input_node_dim, input_edge_dim, hidden_dim, num_layers, dropout, JK='sum'):
super(EGNN, self).__init__()
self.num_layers = num_layers
# List of MLPs
self.egnn_layers = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(self.num_layers - 1):
if layer == 0:
self.egnn_layers.append(EGNNConv(input_node_dim, hidden_dim, hidden_dim, input_edge_dim))
else:
self.egnn_layers.append(EGNNConv(hidden_dim, hidden_dim, hidden_dim, input_edge_dim))
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
self.drop = nn.Dropout(dropout)
self.JK = JK
def forward(self, g, Perturb=None):
hidden_rep = []
node_feats = g.ndata.pop('h').float()
edge_feats = g.edata['e']
coord_feats = g.ndata['pos']
for idx, egnn in enumerate(self.egnn_layers):
if idx == 0 and Perturb is not None:
node_feats = node_feats + Perturb
node_feats, coord_feats = egnn(g, node_feats, coord_feats, edge_feats)
node_feats = self.batch_norms[idx](node_feats)
node_feats = F.relu(node_feats)
node_feats = self.drop(node_feats)
hidden_rep.append(node_feats)
if self.JK == 'sum':
hidden_rep = [h.unsqueeze(0) for h in hidden_rep]
return torch.sum(torch.cat(hidden_rep, dim=0), dim=0)
elif self.JK == 'max':
hidden_rep = [h.unsqueeze(0) for h in hidden_rep]
return torch.max(torch.cat(hidden_rep, dim=0), dim=0)[0]
elif self.JK == 'concat':
return torch.cat(hidden_rep, dim=1)
elif self.JK == 'last':
return hidden_rep[-1]