Spaces:
Running
Running
File size: 4,684 Bytes
3ad8be1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class EGNNConv(nn.Module):
    r"""E(n)-equivariant graph convolution layer (EGNN).

    For every edge (src -> dst) a message is built from the source and
    destination node features, the squared coordinate distance and, when
    ``edge_feat_size > 0``, the edge features. Node features are updated from
    the summed messages. Coordinates are shifted by the mean, over incoming
    edges, of a learned scalar times the normalized coordinate difference.

    Parameters
    ----------
    in_size : int
        Size of the input node features.
    hidden_size : int
        Hidden size shared by the edge, node and coordinate MLPs.
    out_size : int
        Size of the output node features.
    edge_feat_size : int, optional
        Size of the edge features; 0 (default) disables edge features.
    """

    def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
        super().__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.edge_feat_size = edge_feat_size
        act_fn = nn.SiLU()

        # \phi_e: edge MLP.
        # Input = h_src + h_dst + edge features + 1 for the radial feature ||x_i - x_j||^2.
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
            act_fn,
            nn.Linear(hidden_size, hidden_size),
            act_fn,
        )

        # \phi_h: node MLP, input = original node feature + aggregated edge messages.
        self.node_mlp = nn.Sequential(
            nn.Linear(in_size + hidden_size, hidden_size),
            act_fn,
            nn.Linear(hidden_size, out_size),
        )

        # \phi_x: maps each edge message to one scalar weight for the coordinate update.
        self.coord_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            act_fn,
            nn.Linear(hidden_size, 1, bias=False),
        )

    def message(self, edges):
        """Edge UDF computing the feature and coordinate messages for EGNN.

        Reads ``src['h']``, ``dst['h']``, ``data['radial']``,
        ``data['x_diff']`` and, when ``edge_feat_size > 0``, ``data['a']``.

        Returns
        -------
        dict
            ``'msg_h'``: edge MLP output, last dim ``hidden_size``.
            ``'msg_x'``: scalar weight from ``coord_mlp`` times ``x_diff``.
        """
        parts = [edges.src['h'], edges.dst['h'], edges.data['radial']]
        if self.edge_feat_size > 0:
            parts.append(edges.data['a'])
        f = torch.cat(parts, dim=-1)

        msg_h = self.edge_mlp(f)
        msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']
        return {'msg_x': msg_x, 'msg_h': msg_h}

    def forward(self, graph, node_feat, coord_feat, edge_feat=None):
        """Apply one EGNN layer.

        Parameters
        ----------
        graph : DGLGraph
            Graph to convolve over. All feature writes happen inside
            ``graph.local_scope()``, so the caller's graph data is left unchanged.
        node_feat : torch.Tensor
            Node features with last dimension ``in_size``.
        coord_feat : torch.Tensor
            2-D node coordinates ``(num_nodes, D)``. D is presumably 3 (TODO confirm).
        edge_feat : torch.Tensor, optional
            Edge features with last dimension ``edge_feat_size``. Required
            when ``edge_feat_size > 0``.

        Returns
        -------
        tuple of torch.Tensor
            Updated node features (last dim ``out_size``) and updated coordinates.

        Raises
        ------
        ValueError
            If ``edge_feat_size > 0`` and ``edge_feat`` is None.
        """
        # Use a real exception rather than ``assert``: asserts are stripped under
        # ``python -O``, and the failure would then surface inside ``message``.
        if self.edge_feat_size > 0 and edge_feat is None:
            raise ValueError("Edge features must be provided.")

        with graph.local_scope():
            graph.ndata['h'] = node_feat
            graph.ndata['x'] = coord_feat
            if self.edge_feat_size > 0:
                graph.edata['a'] = edge_feat

            # Per-edge coordinate difference (x_src - x_dst) and squared distance.
            graph.apply_edges(fn.u_sub_v('x', 'x', 'x_diff'))
            graph.edata['radial'] = graph.edata['x_diff'].square().sum(dim=1).unsqueeze(-1)
            # Normalize the difference; the tiny epsilon avoids 0/0 when both
            # endpoints share the same coordinates (e.g. self-loops).
            graph.edata['x_diff'] = graph.edata['x_diff'] / (graph.edata['radial'].sqrt() + 1e-30)

            graph.apply_edges(self.message)
            # Coordinates use mean aggregation; features use sum aggregation.
            graph.update_all(fn.copy_e('msg_x', 'm'), fn.mean('m', 'x_neigh'))
            graph.update_all(fn.copy_e('msg_h', 'm'), fn.sum('m', 'h_neigh'))

            h_neigh, x_neigh = graph.ndata['h_neigh'], graph.ndata['x_neigh']
            h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))
            x = coord_feat + x_neigh
            return h, x
class EGNN(nn.Module):
    """Stack of ``EGNNConv`` layers with batch norm, ReLU, dropout and a
    jumping-knowledge (JK) readout over the per-layer node representations.

    Parameters
    ----------
    input_node_dim : int
        Size of the input node features read from ``g.ndata['h']``.
    input_edge_dim : int
        Size of the edge features read from ``g.edata['e']``.
    hidden_dim : int
        Hidden and output size of every ``EGNNConv`` layer.
    num_layers : int
        Number of layers *including* the input; ``num_layers - 1`` ``EGNNConv``
        layers are built. NOTE(review): this GIN-style count looks intentional;
        confirm.
    dropout : float
        Dropout probability applied after each layer.
    JK : str, optional
        Readout over layer outputs: ``'sum'`` (default), ``'max'``, ``'concat'``
        or ``'last'``.
    """

    def __init__(self, input_node_dim, input_edge_dim, hidden_dim, num_layers, dropout, JK='sum'):
        super().__init__()
        self.num_layers = num_layers

        self.egnn_layers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(self.num_layers - 1):
            # Only the first layer consumes the raw input feature size.
            in_dim = input_node_dim if layer == 0 else hidden_dim
            self.egnn_layers.append(EGNNConv(in_dim, hidden_dim, hidden_dim, input_edge_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.drop = nn.Dropout(dropout)
        self.JK = JK

    def forward(self, g, Perturb=None):
        """Encode the nodes of ``g``.

        Parameters
        ----------
        g : DGLGraph
            Graph holding node features ``ndata['h']``, node coordinates
            ``ndata['pos']`` and edge features ``edata['e']``. Its stored data is
            not modified.
        Perturb : torch.Tensor, optional
            Added to the input node features before the first layer. Presumably
            an adversarial perturbation with the same shape as ``ndata['h']``
            (TODO confirm).

        Returns
        -------
        torch.Tensor
            Node representations combined according to ``self.JK``.

        Raises
        ------
        ValueError
            If ``self.JK`` is not one of ``'sum'``, ``'max'``, ``'concat'``, ``'last'``.
        """
        # Read rather than pop: popping removed 'h' from the caller's graph, so a
        # second forward pass on the same graph (e.g. repeated perturbation
        # steps) failed with a KeyError.
        node_feats = g.ndata['h'].float()
        edge_feats = g.edata['e']
        coord_feats = g.ndata['pos']

        hidden_rep = []
        for idx, (egnn, norm) in enumerate(zip(self.egnn_layers, self.batch_norms)):
            if idx == 0 and Perturb is not None:
                node_feats = node_feats + Perturb
            # Coordinates are updated by each layer and fed to the next one.
            node_feats, coord_feats = egnn(g, node_feats, coord_feats, edge_feats)
            node_feats = self.drop(F.relu(norm(node_feats)))
            hidden_rep.append(node_feats)

        if self.JK == 'sum':
            return torch.stack(hidden_rep, dim=0).sum(dim=0)
        if self.JK == 'max':
            return torch.stack(hidden_rep, dim=0).max(dim=0).values
        if self.JK == 'concat':
            return torch.cat(hidden_rep, dim=1)
        if self.JK == 'last':
            return hidden_rep[-1]
        # Previously an unknown mode silently returned None.
        raise ValueError(
            f"Unsupported JK mode {self.JK!r}; expected 'sum', 'max', 'concat' or 'last'."
        )
|