import torch
import torch.nn as nn

from model.egnn import EGNN_Sparse
from model.egnn.utils import get_edge_feature_dims, get_node_feature_dims
from utils.util_functions import get_emb_dim


def _embed_feature_groups(x, group_dims, embeddings):
    """Sum the per-group linear embeddings of column-wise feature groups.

    ``x`` is sliced along dim 1 into consecutive groups whose widths are
    given by ``group_dims``; group ``k`` is passed through ``embeddings[k]``
    and all results are added together. Columns beyond ``sum(group_dims)``
    are ignored. Returns the int ``0`` when ``group_dims`` is empty.
    """
    total = 0
    offset = 0
    for width, emb in zip(group_dims, embeddings):
        total += emb(x[:, offset:offset + width])
        offset += width
    return total


class nodeEncoder(torch.nn.Module):
    """Embed node features: one Linear layer per feature group, summed.

    Group widths come from ``get_node_feature_dims()``; every group is
    projected to ``emb_dim`` with Xavier-uniform initialised weights.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Attribute names are kept as-is: they are part of the state-dict keys.
        self.atom_embedding_list = torch.nn.ModuleList()
        self.node_feature_dim = get_node_feature_dims()

        for dim in self.node_feature_dim:
            emb = torch.nn.Linear(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        """Return the summed group embeddings of ``x`` (groups along dim 1)."""
        return _embed_feature_groups(
            x, self.node_feature_dim, self.atom_embedding_list)


class edgeEncoder(torch.nn.Module):
    """Embed edge features: one Linear layer per feature group, summed.

    Group widths come from ``get_edge_feature_dims()``; every group is
    projected to ``emb_dim`` with Xavier-uniform initialised weights.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # The name ``atom_embedding_list`` (although these are edge
        # embeddings) is kept because it is part of the state-dict keys.
        self.atom_embedding_list = torch.nn.ModuleList()
        self.edge_feature_dims = get_edge_feature_dims()

        for dim in self.edge_feature_dims:
            emb = torch.nn.Linear(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        """Return the summed group embeddings of ``x`` (groups along dim 1)."""
        return _embed_feature_groups(
            x, self.edge_feature_dims, self.atom_embedding_list)


class GNNClassificationHead(nn.Module):
    """Graph-level head: mean-pool per graph, then a two-layer MLP to 1 output.

    The pooling reshapes ``features`` to ``(num_graphs, -1, hidden)``, so it
    only works when every graph contributes the same number of rows.
    """

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, 1)

    def forward(self, features, batch):
        """Pool ``features`` per graph and project to a single value.

        Args:
            features: 2-D tensor whose rows are grouped graph by graph.
            batch: per-row graph indices; the number of graphs is taken as
                ``max(batch) + 1``.

        Returns:
            Tensor of shape ``(num_graphs, 1)``.
        """
        # One tensor reduction instead of the Python builtin ``max``, which
        # would iterate over the tensor element by element.
        num_graphs = int(torch.as_tensor(batch).max()) + 1
        features = features.reshape(num_graphs, -1, features.shape[-1])
        x = torch.mean(features, dim=1)  # average over the rows of each graph
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class ActiveSiteHead(nn.Module):
    """MLP that narrows ``input_dim`` down to ``output_dim``.

    Hidden widths are the powers of 4 (4**1 .. 4**6) lying strictly between
    ``output_dim`` and ``input_dim``, with the smallest and the largest of
    those dropped. Every Linear except the last is followed by Dropout and
    SiLU (in that order).
    """

    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()

        powers_of_four = [4 ** i for i in range(1, 7)]
        between = [d for d in powers_of_four if output_dim < d < input_dim]
        # Layer widths from input to output; ``[1:-1]`` discards the smallest
        # and largest candidate hidden widths.
        widths = [input_dim] + between[1:-1][::-1] + [output_dim]

        layers = []
        # Pair consecutive widths directly rather than looking positions up
        # with ``list.index``, which depends on all widths being unique.
        for in_dim, out_dim in zip(widths, widths[1:]):
            layers += [
                nn.Linear(in_dim, out_dim),
                nn.Dropout(dropout),
                nn.SiLU(),
            ]
        # The final Linear is the output layer: no Dropout/SiLU after it.
        del layers[-2:]
        self.dense = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dim must equal ``input_dim``)."""
        return self.dense(x)


class EGNN(nn.Module):
    """Stack of ``EGNN_Sparse`` layers followed by an ``ActiveSiteHead``.

    Reads ``config.egnn`` (hidden_channels, edge_attr_dim, dropout, mlp_num,
    n_layers, embedding, residual, output_dim), ``config.dataset.property_dim``
    and ``config.model.esm_version``.
    """

    # Number of leading columns of the node tensor that hold ``data.pos``.
    # NOTE(review): assumes 3-D coordinates -- confirm against the dataset.
    _COORD_DIM = 3

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gnn_config = config.egnn
        # The ESM embedding width is recorded but the ESM representation is
        # not part of the layer input: only the property features are.
        self.esm_dim = get_emb_dim(config.model.esm_version)
        self.input_dim = config.dataset.property_dim

        # Dropout is a probability, so it must be parsed as a float;
        # ``int(...)`` would truncate any fractional rate (e.g. 0.1) to 0.
        dropout = float(self.gnn_config["dropout"])

        # ``mpnn_layes`` (sic) is kept: renaming would change checkpoint keys.
        self.mpnn_layes = nn.ModuleList([
            EGNN_Sparse(
                self.input_dim,
                m_dim=int(self.gnn_config["hidden_channels"]),
                edge_attr_dim=int(self.gnn_config["edge_attr_dim"]),
                dropout=dropout,
                mlp_num=int(self.gnn_config["mlp_num"]),
            )
            for _ in range(int(self.gnn_config["n_layers"]))
        ])

        if self.gnn_config["embedding"]:
            self.node_embedding = nodeEncoder(self.input_dim)
            self.edge_embedding = edgeEncoder(self.input_dim)

        self.pred_head = ActiveSiteHead(
            self.input_dim, self.gnn_config['output_dim'], dropout)

    def forward(self, data):
        """Run message passing on ``data`` and return per-row predictions.

        Args:
            data: object exposing ``pos``, ``prop``, ``edge_index``,
                ``edge_attr`` and ``batch``.

        Returns:
            Output of the prediction head applied to the feature columns
            (everything after the first ``_COORD_DIM`` columns).
        """
        pos = data.pos
        prop = data.prop
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        # When the configured property_dim is 41, drop columns 35..55 of
        # ``prop`` (per the original author's note: the one-hot vector that
        # encodes the amino-acid type).
        if self.config.dataset.property_dim == 41:
            prop = torch.cat([prop[:, :35], prop[:, 56:]], dim=1)

        # Coordinates come first, features after.
        input_x = torch.cat([pos, prop], dim=1)

        if self.gnn_config['embedding']:
            input_x = self.node_embedding(input_x)
            edge_attr = self.edge_embedding(edge_attr)

        use_residual = self.gnn_config['residual']
        for layer in self.mpnn_layes:
            h = layer(input_x, edge_index, edge_attr, batch)
            input_x = input_x + h if use_residual else h

        # Strip the coordinate columns before the prediction head.
        features = input_x[:, self._COORD_DIM:]
        return self.pred_head(features)