# M3Site — model/egnn/network.py
import torch
import torch.nn as nn
from model.egnn import EGNN_Sparse
from model.egnn.utils import get_edge_feature_dims, get_node_feature_dims
from utils.util_functions import get_emb_dim
class nodeEncoder(torch.nn.Module):
    """Embed concatenated node features into a shared space.

    The input feature vector is a concatenation of groups whose widths come
    from ``get_node_feature_dims()``; each group gets its own linear
    projection and the per-group embeddings are summed.
    """

    def __init__(self, emb_dim):
        super(nodeEncoder, self).__init__()
        self.node_feature_dim = get_node_feature_dims()
        self.atom_embedding_list = torch.nn.ModuleList()
        for width in self.node_feature_dim:
            projection = torch.nn.Linear(width, emb_dim)
            torch.nn.init.xavier_uniform_(projection.weight.data)
            self.atom_embedding_list.append(projection)

    def forward(self, x):
        """Project each feature slice of ``x`` and sum the results."""
        embedded = 0
        offset = 0
        for projection, width in zip(self.atom_embedding_list, self.node_feature_dim):
            embedded = embedded + projection(x[:, offset:offset + width])
            offset += width
        return embedded
class edgeEncoder(torch.nn.Module):
    """Embed concatenated edge features into a shared space.

    Mirrors ``nodeEncoder`` but uses the widths reported by
    ``get_edge_feature_dims()``: one linear projection per feature group,
    summed into a single embedding.
    """

    def __init__(self, emb_dim):
        super(edgeEncoder, self).__init__()
        self.edge_feature_dims = get_edge_feature_dims()
        self.atom_embedding_list = torch.nn.ModuleList()
        for width in self.edge_feature_dims:
            projection = torch.nn.Linear(width, emb_dim)
            torch.nn.init.xavier_uniform_(projection.weight.data)
            self.atom_embedding_list.append(projection)

    def forward(self, x):
        """Project each edge-feature slice of ``x`` and sum the results."""
        embedded = 0
        offset = 0
        for projection, width in zip(self.atom_embedding_list, self.edge_feature_dims):
            embedded = embedded + projection(x[:, offset:offset + width])
            offset += width
        return embedded
class GNNClassificationHead(nn.Module):
    """Graph-level classification head.

    Mean-pools node features per graph, then applies
    dropout -> dense -> tanh -> dropout -> linear(hidden_size, 1).

    NOTE(review): the reshape assumes every graph in the batch has the same
    number of nodes — confirm the dataloader guarantees fixed-size graphs.
    """

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, 1)

    def forward(self, features, batch):
        # int(batch.max()) + 1 gives the graph count; the previous
        # max(batch) used the Python builtin, which iterates the tensor
        # element-by-element (slow, and needless host sync on GPU tensors).
        num_graphs = int(batch.max()) + 1
        features = features.reshape(num_graphs, -1, features.shape[-1])
        x = torch.mean(features, dim=1)  # average pool over the nodes of each graph
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
class ActiveSiteHead(nn.Module):
    """MLP prediction head stepping from input_dim down to output_dim.

    Intermediate widths are drawn from powers of four strictly between
    output_dim and input_dim (with the first and last of those candidates
    dropped), giving a Linear/Dropout/SiLU stack that ends on a Linear.
    """

    def __init__(self, input_dim, output_dim, dropout):
        super(ActiveSiteHead, self).__init__()
        powers_of_four = [4 ** k for k in range(1, 7)]
        between = [d for d in powers_of_four if output_dim < d < input_dim][1:-1]
        widths = [output_dim] + between + [input_dim]
        descending = widths[::-1]
        blocks = []
        for src, dst in zip(descending[:-1], descending[1:]):
            blocks.extend([nn.Linear(src, dst), nn.Dropout(dropout), nn.SiLU()])
        # Trim the trailing Dropout/SiLU so the final layer is a plain Linear.
        self.dense = nn.Sequential(*blocks[:-2])

    def forward(self, x):
        return self.dense(x)
class EGNN(nn.Module):
    """E(n)-equivariant GNN for per-node active-site prediction.

    Node inputs are 3-D coordinates concatenated with per-residue property
    vectors; a stack of ``EGNN_Sparse`` layers (optionally residual) updates
    them, and ``ActiveSiteHead`` maps the non-coordinate part to logits.
    """

    def __init__(self, config):
        super(EGNN, self).__init__()
        self.config = config
        self.gnn_config = config.egnn
        self.esm_dim = get_emb_dim(config.model.esm_version)
        # self.input_dim = self.esm_dim + config.dataset.property_dim
        self.input_dim = config.dataset.property_dim
        self.mpnn_layes = nn.ModuleList([
            EGNN_Sparse(
                self.input_dim,
                m_dim=int(self.gnn_config["hidden_channels"]),
                edge_attr_dim=int(self.gnn_config["edge_attr_dim"]),
                # BUG FIX: dropout is a probability; the previous
                # int(...) cast truncated values like 0.1 to 0,
                # silently disabling dropout inside every EGNN layer.
                dropout=float(self.gnn_config["dropout"]),
                mlp_num=int(self.gnn_config["mlp_num"]))
            for _ in range(int(self.gnn_config["n_layers"]))])
        if self.gnn_config["embedding"]:
            self.node_embedding = nodeEncoder(self.input_dim)
            self.edge_embedding = edgeEncoder(self.input_dim)
        self.pred_head = ActiveSiteHead(self.input_dim, self.gnn_config['output_dim'], self.gnn_config['dropout'])

    def forward(self, data):
        x, pos, edge_index, edge_attr, batch, esm_rep, prop = (
            data.x, data.pos,
            data.edge_index,
            data.edge_attr, data.batch,
            data.esm_rep, data.prop
        )
        # Drop prop columns 35-55 (presumably an amino-acid-type one-hot
        # sub-vector — TODO confirm) when the configured property_dim
        # expects the reduced 41-dim representation.
        if self.config.dataset.property_dim == 41:
            prop = torch.cat([prop[:, :35], prop[:, 56:]], dim=1)
        input_x = torch.cat([pos, prop], dim=1)
        if self.gnn_config['embedding']:
            input_x = self.node_embedding(input_x)
            edge_attr = self.edge_embedding(edge_attr)
        for layer in self.mpnn_layes:
            h = layer(input_x, edge_index, edge_attr, batch)
            if self.gnn_config['residual']:
                input_x = input_x + h
            else:
                input_x = h
        # The first 3 columns carry coordinates; predict from the feature
        # part only.
        x = input_x[:, 3:]
        x = self.pred_head(x)
        return x