import torch
import torch.nn as nn
from model.egnn import EGNN_Sparse
from model.egnn.utils import get_edge_feature_dims, get_node_feature_dims
from utils.util_functions import get_emb_dim
class nodeEncoder(torch.nn.Module):
    """Projects concatenated per-block node features into a shared embedding space."""

    def __init__(self, emb_dim):
        super(nodeEncoder, self).__init__()
        self.atom_embedding_list = torch.nn.ModuleList()
        self.node_feature_dim = get_node_feature_dims()
        # One linear projection per feature block, all mapping into emb_dim.
        for i, dim in enumerate(self.node_feature_dim):
            emb = torch.nn.Linear(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        # Slice x into consecutive feature blocks, project each block, and sum the projections.
        x_embedding = 0
        feature_dim_count = 0
        for i in range(len(self.node_feature_dim)):
            x_embedding += self.atom_embedding_list[i](
                x[:, feature_dim_count:feature_dim_count + self.node_feature_dim[i]])
            feature_dim_count += self.node_feature_dim[i]
        return x_embedding
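# Worked example (hypothetical dims): if get_node_feature_dims() returned [20, 3],
# a node feature row of width 23 would be split as x[:, 0:20] and x[:, 20:23],
# each block projected to emb_dim by its own Linear layer, and the two summed:
#
#   enc = nodeEncoder(emb_dim=64)
#   out = enc(torch.randn(10, 23))   # -> shape (10, 64)
#
# edgeEncoder below applies the same block-wise scheme to edge features.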
class edgeEncoder(torch.nn.Module):
    """Block-wise linear embedding for edge features, mirroring nodeEncoder."""

    def __init__(self, emb_dim):
        super(edgeEncoder, self).__init__()
        # The list keeps nodeEncoder's attribute name but holds edge-feature projections.
        self.atom_embedding_list = torch.nn.ModuleList()
        self.edge_feature_dims = get_edge_feature_dims()
        for i, dim in enumerate(self.edge_feature_dims):
            emb = torch.nn.Linear(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        feature_dim_count = 0
        for i in range(len(self.edge_feature_dims)):
            x_embedding += self.atom_embedding_list[i](
                x[:, feature_dim_count:feature_dim_count + self.edge_feature_dims[i]])
            feature_dim_count += self.edge_feature_dims[i]
        return x_embedding
class GNNClassificationHead(nn.Module):
    """Head for graph-level classification tasks."""

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, 1)

    def forward(self, features, batch):
        # Reshape the flat node-feature matrix to (num_graphs, nodes_per_graph, hidden);
        # this assumes every graph in the batch has the same number of nodes.
        features = features.reshape(int(batch.max()) + 1, -1, features.shape[-1])
        x = torch.mean(features, dim=1)  # average-pool over the nodes of each graph
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
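# Shape sketch (hypothetical sizes): with 4 graphs of 10 nodes each and hidden_size=64,
# features is (40, 64) and batch is [0]*10 + [1]*10 + [2]*10 + [3]*10; the reshape
# yields (4, 10, 64), mean-pooling gives (4, 64), and out_proj returns (4, 1).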
class ActiveSiteHead(nn.Module):
    """MLP head that steps the width down from input_dim to output_dim through powers of 4."""

    def __init__(self, input_dim, output_dim, dropout):
        super(ActiveSiteHead, self).__init__()
        # Candidate intermediate widths: 4, 16, 64, 256, 1024, 4096.
        dims = [4**i for i in range(1, 7)]
        # Keep the powers of 4 strictly between output_dim and input_dim, drop the
        # first and last of those, then bracket the result with the two endpoints.
        lin_dims = [output_dim] + [x for x in dims if output_dim < x < input_dim][1:-1] + [input_dim]
        layers = []
        # Walk from the widest layer down, pairing each width with the next-smaller one.
        for in_dim in lin_dims[::-1][:-1]:
            layers.append(nn.Linear(in_dim, lin_dims[lin_dims.index(in_dim) - 1]))
            layers.append(nn.Dropout(dropout))
            layers.append(nn.SiLU())
        # Remove the trailing SiLU and Dropout so the head ends on a plain Linear.
        layers.pop()
        layers.pop()
        self.dense = nn.Sequential(*layers)

    def forward(self, x):
        return self.dense(x)
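# Worked example (hypothetical dims): with input_dim=1280 and output_dim=2,
# dims filters to [4, 16, 64, 256, 1024], the [1:-1] slice leaves [16, 64, 256],
# so lin_dims = [2, 16, 64, 256, 1280] and the head becomes
# Linear(1280, 256) -> Linear(256, 64) -> Linear(64, 16) -> Linear(16, 2),
# with Dropout + SiLU between consecutive Linear layers.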
class EGNN(nn.Module):
    def __init__(self, config):
        super(EGNN, self).__init__()
        self.config = config
        self.gnn_config = config.egnn
        self.esm_dim = get_emb_dim(config.model.esm_version)
        # self.input_dim = self.esm_dim + config.dataset.property_dim
        self.input_dim = config.dataset.property_dim
        self.mpnn_layes = nn.ModuleList([
            EGNN_Sparse(
                self.input_dim,
                m_dim=int(self.gnn_config["hidden_channels"]),
                edge_attr_dim=int(self.gnn_config["edge_attr_dim"]),
                # dropout is a probability; casting it to int would silently zero it out
                dropout=float(self.gnn_config["dropout"]),
                mlp_num=int(self.gnn_config["mlp_num"]))
            for _ in range(int(self.gnn_config["n_layers"]))])
        if self.gnn_config["embedding"]:
            self.node_embedding = nodeEncoder(self.input_dim)
            self.edge_embedding = edgeEncoder(self.input_dim)
        self.pred_head = ActiveSiteHead(self.input_dim, self.gnn_config['output_dim'], self.gnn_config['dropout'])
        # self.lin = nn.Linear(input_dim, self.gnn_config['output_dim'])
        # self.droplayer = nn.Dropout(int(self.gnn_config["dropout"]))
    def forward(self, data):
        x, pos, edge_index, edge_attr, batch, esm_rep, prop = (
            data.x, data.pos,
            data.edge_index,
            data.edge_attr, data.batch,
            data.esm_rep, data.prop
        )
        # Drop columns 35-55 of prop (the one-hot block encoding amino-acid type).
        if self.config.dataset.property_dim == 41:
            prop = torch.cat([prop[:, :35], prop[:, 56:]], dim=1)
        input_x = torch.cat([pos, prop], dim=1)
        # input_x = torch.cat([pos, esm_rep, prop], dim=1)
        # input_x = torch.cat([pos, input_x], dim=1)
        if self.gnn_config['embedding']:
            input_x = self.node_embedding(input_x)
            edge_attr = self.edge_embedding(edge_attr)
        for i, layer in enumerate(self.mpnn_layes):
            h = layer(input_x, edge_index, edge_attr, batch)
            if self.gnn_config['residual']:
                input_x = input_x + h
            else:
                input_x = h
        # The first three channels carry coordinates; strip them before prediction.
        x = input_x[:, 3:]
        x = self.pred_head(x)
        # x = self.droplayer(x)
        # x = self.lin(x)
        # return x, input_x[:, 3:]
        return x
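

# A minimal sketch for sanity-checking the two self-contained heads in isolation
# (needs only torch); the sizes below are hypothetical examples, not values taken
# from the project config.
if __name__ == "__main__":
    head = ActiveSiteHead(input_dim=1280, output_dim=2, dropout=0.1)
    out = head(torch.randn(8, 1280))
    print(out.shape)  # torch.Size([8, 2])

    clf = GNNClassificationHead(hidden_size=64, hidden_dropout_prob=0.1)
    feats = torch.randn(4 * 10, 64)                 # 4 graphs, 10 nodes each
    batch = torch.arange(4).repeat_interleave(10)   # graph index for every node
    print(clf(feats, batch).shape)  # torch.Size([4, 1])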