import torch
import torch.nn as nn
from model.egnn import EGNN_Sparse
from model.egnn.utils import get_edge_feature_dims, get_node_feature_dims
from utils.util_functions import get_emb_dim


class nodeEncoder(torch.nn.Module):
    """Embeds each concatenated node-feature block with its own linear layer
    and sums the projections into a single emb_dim-sized embedding."""

    def __init__(self, emb_dim):
        super(nodeEncoder, self).__init__()
        self.atom_embedding_list = torch.nn.ModuleList()
        self.node_feature_dim = get_node_feature_dims()
        # One linear projection per feature block, all mapping into emb_dim.
        for i, dim in enumerate(self.node_feature_dim):
            emb = torch.nn.Linear(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        # Slice x into its feature blocks, project each block, and sum the results.
        x_embedding = 0
        feature_dim_count = 0
        for i in range(len(self.node_feature_dim)):
            x_embedding += self.atom_embedding_list[i](
                x[:, feature_dim_count:feature_dim_count + self.node_feature_dim[i]])
            feature_dim_count += self.node_feature_dim[i]
        return x_embedding
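
# Shape sketch for nodeEncoder (hypothetical block sizes, not the repo's actual
# get_node_feature_dims() output): with blocks [21, 8] and emb_dim=64, an input
# x of shape (num_nodes, 29) is split into x[:, :21] and x[:, 21:29], each slice
# is projected to (num_nodes, 64), and the two projections are summed.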


class edgeEncoder(torch.nn.Module):
    """Embeds each concatenated edge-feature block with its own linear layer
    and sums the projections, mirroring nodeEncoder."""

    def __init__(self, emb_dim):
        super(edgeEncoder, self).__init__()
        # Attribute name kept from nodeEncoder; these layers embed edge features.
        self.atom_embedding_list = torch.nn.ModuleList()
        self.edge_feature_dims = get_edge_feature_dims()
        for i, dim in enumerate(self.edge_feature_dims):
            emb = torch.nn.Linear(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        # Same block-wise project-and-sum scheme as nodeEncoder, over edge features.
        x_embedding = 0
        feature_dim_count = 0
        for i in range(len(self.edge_feature_dims)):
            x_embedding += self.atom_embedding_list[i](
                x[:, feature_dim_count:feature_dim_count + self.edge_feature_dims[i]])
            feature_dim_count += self.edge_feature_dims[i]
        return x_embedding


class GNNClassificationHead(nn.Module):
    """Head for graph-level classification: mean-pools node features per graph,
    then applies dropout -> dense -> tanh -> dropout -> linear."""

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, 1)

    def forward(self, features, batch):
        # Assumes every graph in the batch has the same number of nodes, so the
        # flat (num_nodes, dim) tensor reshapes cleanly to (num_graphs, nodes, dim).
        num_graphs = int(batch.max()) + 1
        features = features.reshape(num_graphs, -1, features.shape[-1])
        x = torch.mean(features, dim=1)  # average-pool over the nodes of each graph
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
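
# Note: with variable graph sizes the reshape above silently mis-groups nodes or
# fails outright; a scatter-based pool such as
# torch_geometric.nn.global_mean_pool(features, batch) (assuming torch_geometric
# is installed, which EGNN_Sparse suggests) handles that case directly.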


class ActiveSiteHead(nn.Module):
    """Per-node prediction head: a pyramid of Linear/Dropout/SiLU blocks whose
    widths step down through powers of 4 from input_dim to output_dim."""

    def __init__(self, input_dim, output_dim, dropout):
        super(ActiveSiteHead, self).__init__()
        dims = [4 ** i for i in range(1, 7)]  # candidate widths: 4, 16, ..., 4096
        # Keep the powers of 4 strictly between output_dim and input_dim, drop the
        # smallest and largest of those, and bracket with the two endpoints.
        lin_dims = [output_dim] + [x for x in dims if output_dim < x < input_dim][1:-1] + [input_dim]
        layers = []
        # Pair consecutive widths from input_dim down to output_dim.
        rev = lin_dims[::-1]
        for in_dim, out_dim in zip(rev[:-1], rev[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.Dropout(dropout))
            layers.append(nn.SiLU())
        layers = layers[:-2]  # drop the trailing Dropout/SiLU so the head ends with a Linear
        self.dense = nn.Sequential(*layers)

    def forward(self, x):
        x = self.dense(x)
        return x
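
# Worked example for the width pyramid (illustrative dims, not from the source):
# input_dim=1280, output_dim=2 -> powers of 4 strictly between them are
# [4, 16, 64, 256, 1024]; trimming the ends leaves [16, 64, 256], so
# lin_dims = [2, 16, 64, 256, 1280] and the head is
# Linear(1280, 256) -> Linear(256, 64) -> Linear(64, 16) -> Linear(16, 2),
# with Dropout + SiLU between consecutive Linears.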


class EGNN(nn.Module):
    def __init__(self, config):
        super(EGNN, self).__init__()
        self.config = config
        self.gnn_config = config.egnn
        self.esm_dim = get_emb_dim(config.model.esm_version)
        # self.input_dim = self.esm_dim + config.dataset.property_dim
        self.input_dim = config.dataset.property_dim
        self.mpnn_layes = nn.ModuleList([
            EGNN_Sparse(
                self.input_dim,
                m_dim=int(self.gnn_config["hidden_channels"]),
                edge_attr_dim=int(self.gnn_config["edge_attr_dim"]),
                dropout=float(self.gnn_config["dropout"]),  # dropout is a probability; keep it a float
                mlp_num=int(self.gnn_config["mlp_num"]))
            for _ in range(int(self.gnn_config["n_layers"]))])
        if self.gnn_config["embedding"]:
            self.node_embedding = nodeEncoder(self.input_dim)
            self.edge_embedding = edgeEncoder(self.input_dim)
        self.pred_head = ActiveSiteHead(self.input_dim, self.gnn_config['output_dim'], self.gnn_config['dropout'])
        # self.lin = nn.Linear(input_dim, self.gnn_config['output_dim'])
        # self.droplayer = nn.Dropout(float(self.gnn_config["dropout"]))
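
    # Keys read from gnn_config (values below are illustrative assumptions, not
    # taken from this repo's configs): hidden_channels=128, edge_attr_dim=16,
    # dropout=0.1, mlp_num=2, n_layers=4, embedding=False, residual=True,
    # output_dim=2.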

    def forward(self, data):
        x, pos, edge_index, edge_attr, batch, esm_rep, prop = (
            data.x, data.pos,
            data.edge_index,
            data.edge_attr, data.batch,
            data.esm_rep, data.prop
        )
        # Drop the amino-acid-type one-hot block (columns 35-55) from prop.
        if self.config.dataset.property_dim == 41:
            prop = torch.cat([prop[:, :35], prop[:, 56:]], dim=1)
        # EGNN_Sparse expects the coordinates as the first three feature columns.
        input_x = torch.cat([pos, prop], dim=1)
        # input_x = torch.cat([pos, esm_rep, prop], dim=1)
        # input_x = torch.cat([pos, input_x], dim=1)
        if self.gnn_config['embedding']:
            input_x = self.node_embedding(input_x)
            edge_attr = self.edge_embedding(edge_attr)
        for i, layer in enumerate(self.mpnn_layes):
            h = layer(input_x, edge_index, edge_attr, batch)
            if self.gnn_config['residual']:
                input_x = input_x + h
            else:
                input_x = h
        # Strip the three coordinate columns before the per-node prediction head.
        x = input_x[:, 3:]
        x = self.pred_head(x)
        # x = self.droplayer(x)
        # x = self.lin(x)
        # return x, input_x[:, 3:]
        return x
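

if __name__ == "__main__":
    # Minimal smoke test for the two pure-torch heads (a sketch; the sizes below
    # are illustrative assumptions, not values from any config in this repo).
    feats = torch.randn(12, 64)              # 12 nodes with 64-dim features
    batch = torch.tensor([0] * 6 + [1] * 6)  # two graphs of 6 nodes each
    cls_head = GNNClassificationHead(hidden_size=64, hidden_dropout_prob=0.1)
    print(cls_head(feats, batch).shape)      # torch.Size([2, 1])
    site_head = ActiveSiteHead(input_dim=64, output_dim=2, dropout=0.1)
    print(site_head(feats).shape)            # torch.Size([12, 2])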