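# Model definitions for GNN-based protein-ligand binding affinity prediction.
# Each model encodes the ligand graph and the protein graph, builds node features
# for a ligand-protein interaction graph, and predicts affinity from a readout of
# the interaction-edge features, optionally conditioned on an assay-description
# embedding (see the pretrain_use_assay_description branches below).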
import torch
import torch.nn as nn
from UltraFlow import layers, losses
class IGN_basic(nn.Module):
    def __init__(self, config):
        super(IGN_basic, self).__init__()
        self.config = config
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        # A single graph encoder shared by the ligand and protein graphs.
        self.graph_conv = layers.ModifiedAttentiveFPGNNV2(config.model.lig_node_dim, config.model.lig_edge_dim, config.model.num_layers, config.model.hidden_dim, config.model.dropout, config.model.jk)
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer_IGN_basic(config.model.hidden_dim * config.model.num_layers + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer_IGN_basic(config.model.hidden_dim + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        self.softmax = nn.Softmax(dim=1)
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.graph_conv(bg_lig)
        node_feats_prot = self.graph_conv(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        return affinity_pred, graph_embedding, ranking_assay_embedding
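
    # alignfeature: concatenate ligand and protein node features, then reorder them
    # per batch sample (each sample's ligand nodes first, followed by its protein
    # nodes) so the rows match the node ordering of the batched interaction graph.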
    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
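
# IGN mirrors IGN_basic but encodes the ligand and the protein with two separate
# AttentiveFP-style encoders instead of a single shared graph encoder.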
class IGN(nn.Module):
    def __init__(self, config):
        super(IGN, self).__init__()
        self.config = config
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.ligand_conv = layers.ModifiedAttentiveFPGNNV2(config.model.lig_node_dim, config.model.lig_edge_dim, config.model.num_layers, config.model.hidden_dim, config.model.dropout, config.model.jk)
        self.protein_conv = layers.ModifiedAttentiveFPGNNV2(config.model.pro_node_dim, config.model.pro_edge_dim, config.model.num_layers, config.model.hidden_dim, config.model.dropout, config.model.jk)
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(config.model.hidden_dim * (config.model.num_layers + config.model.num_layers) + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(config.model.hidden_dim * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        self.softmax = nn.Softmax(dim=1)
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.ligand_conv(bg_lig)
        node_feats_prot = self.protein_conv(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        return affinity_pred, graph_embedding, ranking_assay_embedding

    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
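
# GNNs: thin wrapper that selects a node-feature encoder backbone (GCN, GAT, GIN,
# EGNN, or AttentiveFP) from the layers module according to the GNN argument.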
class GNNs(nn.Module):
    def __init__(self, nLigNode, nLigEdge, nLayer, nHid, JK, GNN):
        super(GNNs, self).__init__()
        if GNN == 'GCN':
            self.Encoder = layers.GCN(nLigNode, hidden_feats=[nHid] * nLayer)
        elif GNN == 'GAT':
            self.Encoder = layers.GAT(nLigNode, hidden_feats=[nHid] * nLayer)
        elif GNN == 'GIN':
            self.Encoder = layers.GIN(nLigNode, nHid, nLayer, num_mlp_layers=2, dropout=0.1, learn_eps=False,
                                      neighbor_pooling_type='sum', JK=JK)
        elif GNN == 'EGNN':
            self.Encoder = layers.EGNN(nLigNode, nLigEdge, nHid, nLayer, dropout=0.1, JK=JK)
        elif GNN == 'AttentiveFP':
            self.Encoder = layers.ModifiedAttentiveFPGNNV2(nLigNode, nLigEdge, nLayer, nHid, 0.1, JK)

    def forward(self, Graph, Perturb=None):
        Node_Rep = self.Encoder(Graph, Perturb)
        return Node_Rep
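
# Affinity_GNNs: single-task affinity model built from configurable ligand and
# protein encoders (GNNs), an interaction-graph convolution (noncov_graph), a
# readout, and an optional assay-description MLP added to the graph embedding.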
class Affinity_GNNs(nn.Module):
    def __init__(self, config):
        super(Affinity_GNNs, self).__init__()
        lig_node_dim = config.model.lig_node_dim
        lig_edge_dim = config.model.lig_edge_dim
        pro_node_dim = config.model.pro_node_dim
        pro_edge_dim = config.model.pro_edge_dim
        layer_num = config.model.num_layers
        hidden_dim = config.model.hidden_dim
        jk = config.model.jk
        GNN = config.model.GNN_type
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.lig_encoder = GNNs(lig_node_dim, lig_edge_dim, layer_num, hidden_dim, jk, GNN)
        self.pro_encoder = GNNs(pro_node_dim, pro_edge_dim, layer_num, hidden_dim, jk, GNN)
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * (layer_num + layer_num) + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        self.softmax = nn.Softmax(dim=1)
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.lig_encoder(bg_lig)
        node_feats_prot = self.pro_encoder(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        return affinity_pred, graph_embedding, ranking_assay_embedding

    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
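
# affinity_head: standalone head that maps a precomputed graph embedding (plus an
# optional assay-description embedding) to an affinity prediction.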
class affinity_head(nn.Module):
    def __init__(self, config):
        super(affinity_head, self).__init__()
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)

    def forward(self, graph_embedding, ass_des):
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        return affinity_pred
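
# ASRP_head: prediction head that combines a pointwise regression loss (element-wise
# MSE) with a pairwise ranking loss (losses.pairwise_BCE_loss) computed over sample
# pairs in the batch; select_flag masks which samples contribute to each loss.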
class ASRP_head(nn.Module):
    def __init__(self, config):
        super(ASRP_head, self).__init__()
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout, config.model.num_head, config.model.attn_merge)
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim, config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)
        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1), config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        # element-wise (unreduced) MSE so per-sample losses can be masked by select_flag
        self.regression_loss_fn = nn.MSELoss(reduction='none')
        self.ranking_loss_fn = losses.pairwise_BCE_loss(config)
        self.pairwise_two_tower_regression_loss = config.train.pairwise_two_tower_regression_loss
        if self.pairwise_two_tower_regression_loss:
            print('use two tower regression loss')

    def forward(self, bg_inter, bond_feats_inter, ass_des, labels, select_flag):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        y_pred_num = len(affinity_pred)
        # the batch is built from sample pairs, so it must contain an even number of items
        assert y_pred_num % 2 == 0
        if self.pairwise_two_tower_regression_loss:
            regression_loss = self.regression_loss_fn(affinity_pred, labels)
            labels_select = labels[select_flag]
            affinity_pred_select = affinity_pred[select_flag]
            regression_loss_select = regression_loss[select_flag].sum()
        else:
            regression_loss = self.regression_loss_fn(affinity_pred[:y_pred_num // 2], labels[:y_pred_num // 2])
            labels_select = labels[:y_pred_num // 2][select_flag[:y_pred_num // 2]]
            affinity_pred_select = affinity_pred[:y_pred_num // 2][select_flag[:y_pred_num // 2]]
            regression_loss_select = regression_loss[select_flag[:y_pred_num // 2]].sum()
        ranking_loss, relation, relation_pred = self.ranking_loss_fn(graph_embedding, labels, ranking_assay_embedding)
        ranking_loss_select = ranking_loss[select_flag[:y_pred_num // 2]].sum()
        relation_select = relation[select_flag[:y_pred_num // 2]]
        relation_pred_select = relation_pred[select_flag[:y_pred_num // 2]]
        return regression_loss_select, ranking_loss_select,\
               labels_select, affinity_pred_select,\
               relation_select, relation_pred_select

    def forward_pointwise(self, bg_inter, bond_feats_inter, ass_des, labels, select_flag):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        affinity_pred = self.FC(graph_embedding)
        regression_loss = self.regression_loss_fn(affinity_pred, labels)
        regression_loss_select = regression_loss[select_flag].sum()
        labels_select = labels[select_flag]
        affinity_pred_select = affinity_pred[select_flag]
        return regression_loss_select, labels_select, affinity_pred_select

    def evaluate_mtl(self, bg_inter, bond_feats_inter, ass_des, labels):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))
        # enumerate all ordered sample pairs (i, j), i != j, for pairwise relation evaluation
        n = graph_embedding.shape[0]
        pair_a_index, pair_b_index = [], []
        for i in range(n):
            pair_a_index.extend([i] * (n - 1))
            pair_b_index.extend([j for j in range(n) if i != j])
        pair_index = pair_a_index + pair_b_index
        _, relation, relation_pred = self.ranking_loss_fn(graph_embedding[pair_index], labels[pair_index], ranking_assay_embedding[pair_index])
        return affinity_pred, relation, relation_pred
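
# Affinity_GNNs_MTL: multi-task model that shares the ligand/protein encoders and
# the interaction-graph convolution, and dispatches each batch to per-measurement
# ASRP heads (IC50/Kd/Ki or IC50/K) selected by the *_f flag tensors in the batch.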
class Affinity_GNNs_MTL(nn.Module):
    def __init__(self, config):
        super(Affinity_GNNs_MTL, self).__init__()
        lig_node_dim = config.model.lig_node_dim
        lig_edge_dim = config.model.lig_edge_dim
        pro_node_dim = config.model.pro_node_dim
        pro_edge_dim = config.model.pro_edge_dim
        layer_num = config.model.num_layers
        hidden_dim = config.model.hidden_dim
        jk = config.model.jk
        GNN = config.model.GNN_type
        self.multi_task = config.train.multi_task
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description
        self.lig_encoder = GNNs(lig_node_dim, lig_edge_dim, layer_num, hidden_dim, jk, GNN)
        self.pro_encoder = GNNs(pro_node_dim, pro_edge_dim, layer_num, hidden_dim, jk, GNN)
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * (layer_num + layer_num) + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(hidden_dim * 2 + config.model.inter_edge_dim, config.model.inter_out_dim, config.model.dropout)
        self.softmax = nn.Softmax(dim=1)
        # one ASRP head per measurement type
        if self.multi_task == 'IC50KdKi':
            self.IC50_ASRP_head = ASRP_head(config)
            self.Kd_ASRP_head = ASRP_head(config)
            self.Ki_ASRP_head = ASRP_head(config)
        elif self.multi_task == 'IC50K':
            self.IC50_ASRP_head = ASRP_head(config)
            self.K_ASRP_head = ASRP_head(config)
        self.config = config

    def forward(self, batch, ASRP=True, Perturb=None, Perturb_v=None):
        # Perturb is an optional perturbation tensor; Perturb_v selects whether it is
        # applied to the encoder node features ('v_intra') or to the interaction-graph
        # node features ('v_inter').
        if self.multi_task == 'IC50KdKi':
            bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, Kd_f, Ki_f = batch
            lig_node_feats_init = bg_lig.ndata['h']
            pro_node_feats_init = bg_prot.ndata['h']
            if Perturb is not None and Perturb_v == 'v_intra':
                node_feats_lig = self.lig_encoder(bg_lig, Perturb[:bg_lig.number_of_nodes()])
                node_feats_prot = self.pro_encoder(bg_prot, Perturb[bg_lig.number_of_nodes():])
            else:
                node_feats_lig = self.lig_encoder(bg_lig)
                node_feats_prot = self.pro_encoder(bg_prot)
            if self.config.train.encoder_ablation == 'interact':
                return node_feats_lig, node_feats_prot
            elif self.config.train.encoder_ablation == 'ligand':
                node_feats_lig = node_feats_lig.zero_()
                node_feats_lig[:, :self.config.model.lig_node_dim] = lig_node_feats_init
            elif self.config.train.encoder_ablation == 'protein':
                node_feats_prot = node_feats_prot.zero_()
                node_feats_prot[:, :self.config.model.pro_node_dim] = pro_node_feats_init
            bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
            if Perturb is not None and Perturb_v == 'v_inter':
                bg_inter.ndata['h'] = bg_inter.ndata['h'] + Perturb
            bond_feats_inter = self.noncov_graph(bg_inter)
            if ASRP:
                return self.multi_head_pred(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f)
            else:
                return self.multi_head_pointwise(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f)
        elif self.multi_task == 'IC50K':
            bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, K_f = batch
            lig_node_feats_init = bg_lig.ndata['h']
            pro_node_feats_init = bg_prot.ndata['h']
            if Perturb is not None and Perturb_v == 'v_intra':
                node_feats_lig = self.lig_encoder(bg_lig, Perturb[:bg_lig.number_of_nodes()])
                node_feats_prot = self.pro_encoder(bg_prot, Perturb[bg_lig.number_of_nodes():])
            else:
                node_feats_lig = self.lig_encoder(bg_lig)
                node_feats_prot = self.pro_encoder(bg_prot)
            if self.config.train.encoder_ablation == 'interact':
                return node_feats_lig, node_feats_prot
            elif self.config.train.encoder_ablation == 'ligand':
                node_feats_lig = node_feats_lig.zero_()
                node_feats_lig[:, :self.config.model.lig_node_dim] = lig_node_feats_init
            elif self.config.train.encoder_ablation == 'protein':
                node_feats_prot = node_feats_prot.zero_()
                node_feats_prot[:, :self.config.model.pro_node_dim] = pro_node_feats_init
            bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
            if Perturb is not None and Perturb_v == 'v_inter':
                bg_inter.ndata['h'] = bg_inter.ndata['h'] + Perturb
            bond_feats_inter = self.noncov_graph(bg_inter)
            if ASRP:
                return self.multi_head_pred_v2(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f)
            else:
                return self.multi_head_pointwise_v2(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f)

    def multi_head_pointwise(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        regression_loss_IC50, affinity_IC50, affinity_pred_IC50 = \
            self.IC50_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_Kd, affinity_Kd, affinity_pred_Kd = \
            self.Kd_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, Kd_f)
        regression_loss_Ki, affinity_Ki, affinity_pred_Ki = \
            self.Ki_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, Ki_f)
        return (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki),\
               (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
               (affinity_IC50, affinity_Kd, affinity_Ki)

    def multi_head_pointwise_v2(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f):
        regression_loss_IC50, affinity_IC50, affinity_pred_IC50 = \
            self.IC50_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_K, affinity_K, affinity_pred_K = \
            self.K_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, K_f)
        return (regression_loss_IC50, regression_loss_K),\
               (affinity_pred_IC50, affinity_pred_K), \
               (affinity_IC50, affinity_K)

    def multi_head_pred(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        regression_loss_IC50, ranking_loss_IC50, \
        affinity_IC50, affinity_pred_IC50, \
        relation_IC50, relation_pred_IC50 = self.IC50_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_Kd, ranking_loss_Kd, \
        affinity_Kd, affinity_pred_Kd, \
        relation_Kd, relation_pred_Kd = self.Kd_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, Kd_f)
        regression_loss_Ki, ranking_loss_Ki, \
        affinity_Ki, affinity_pred_Ki, \
        relation_Ki, relation_pred_Ki = self.Ki_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, Ki_f)
        return (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki),\
               (ranking_loss_IC50, ranking_loss_Kd, ranking_loss_Ki), \
               (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
               (relation_pred_IC50, relation_pred_Kd, relation_pred_Ki), \
               (affinity_IC50, affinity_Kd, affinity_Ki), \
               (relation_IC50, relation_Kd, relation_Ki)

    def multi_head_pred_v2(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f):
        regression_loss_IC50, ranking_loss_IC50, \
        affinity_IC50, affinity_pred_IC50, \
        relation_IC50, relation_pred_IC50 = self.IC50_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_K, ranking_loss_K, \
        affinity_K, affinity_pred_K, \
        relation_K, relation_pred_K = self.K_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, K_f)
        return (regression_loss_IC50, regression_loss_K),\
               (ranking_loss_IC50, ranking_loss_K), \
               (affinity_pred_IC50, affinity_pred_K), \
               (relation_pred_IC50, relation_pred_K), \
               (affinity_IC50, affinity_K), \
               (relation_IC50, relation_K)

    def multi_head_evaluate(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        # exactly one measurement type is expected per evaluation batch
        if sum(IC50_f):
            assert sum(Kd_f) == 0 and sum(Ki_f) == 0
            return self.IC50_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)
        elif sum(Kd_f):
            assert sum(IC50_f) == 0 and sum(Ki_f) == 0
            return self.Kd_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)
        elif sum(Ki_f):
            assert sum(IC50_f) == 0 and sum(Kd_f) == 0
            return self.Ki_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)

    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
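
# interact_ablation / interact_ablation_head: ablation heads that predict affinity
# directly from a precomputed graph embedding with a plain FC head and per-sample
# MSE, without the assay-description or ranking branches.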
class interact_ablation(nn.Module):
    def __init__(self, config):
        super(interact_ablation, self).__init__()
        self.IC50_ASRP_head = interact_ablation_head(config)
        self.K_ASRP_head = interact_ablation_head(config)
        self.config = config

    def forward(self, graph_embedding, labels, IC50_f, K_f):
        regression_loss_IC50, \
        affinity_IC50, affinity_pred_IC50 = self.IC50_ASRP_head(graph_embedding, labels, IC50_f)
        regression_loss_K, \
        affinity_K, affinity_pred_K = self.K_ASRP_head(graph_embedding, labels, K_f)
        return (regression_loss_IC50, regression_loss_K),\
               (affinity_pred_IC50, affinity_pred_K), \
               (affinity_IC50, affinity_K)

class interact_ablation_head(nn.Module):
    def __init__(self, config):
        super(interact_ablation_head, self).__init__()
        self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fintune_fc_hidden_dim, config.model.dropout,
                            config.model.out_dim)
        # element-wise (unreduced) MSE so per-sample losses can be masked by select_flag
        self.regression_loss_fn = nn.MSELoss(reduction='none')

    def forward(self, graph_embedding, labels, select_flag):
        affinity_pred = self.FC(graph_embedding)
        regression_loss = self.regression_loss_fn(affinity_pred, labels)
        regression_loss_select = regression_loss[select_flag].sum()
        labels_select = labels[select_flag]
        affinity_pred_select = affinity_pred[select_flag]
        return regression_loss_select, labels_select, affinity_pred_select
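
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes batches prepared by the project's dataloaders (DGL batched graphs
# plus label/flag tensors) and a nested config exposing the fields referenced
# above (config.model.*, config.train.*, config.data.*).
#
#   model = Affinity_GNNs_MTL(config)   # with config.train.multi_task == 'IC50K'
#   # batch = (bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, K_f)
#   reg_losses, preds, targets = model(batch, ASRP=False)   # pointwise heads only
#   reg_losses, rank_losses, preds, rel_preds, targets, relations = model(batch)  # ASRP
# ---------------------------------------------------------------------------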