import torch
import torch.nn as nn


# TODO: this is a simple, hard-coded model; refactor it so it can be customized via a config file only.
class HyperAttentionDTI(nn.Module):
    def __init__(
        self,
        protein_kernel=(4, 8, 12),
        drug_kernel=(4, 6, 8),
        conv=40,
        char_dim=64,
        protein_max_len=1000,
        drug_max_len=100,
    ):
        super().__init__()
        # Character-level embedding for the drug (SMILES) sequence; index 0 is padding.
        self.drug_embed = nn.Embedding(63, char_dim, padding_idx=0)
        self.drug_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=drug_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=drug_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=drug_kernel[2]),
            nn.ReLU(),
        )
        # Global max pool over the sequence length remaining after three unpadded
        # convolutions: drug_max_len - sum(drug_kernel) + 3.
        self.drug_max_pool = nn.MaxPool1d(
            drug_max_len - drug_kernel[0] - drug_kernel[1] - drug_kernel[2] + 3)
        # Character-level embedding for the protein sequence; index 0 is padding.
        self.protein_embed = nn.Embedding(26, char_dim, padding_idx=0)
        self.protein_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=protein_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=protein_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=protein_kernel[2]),
            nn.ReLU(),
        )
        self.protein_max_pool = nn.MaxPool1d(
            protein_max_len - protein_kernel[0] - protein_kernel[1] - protein_kernel[2] + 3)
        # Pairwise attention over the drug and protein feature maps.
        self.attention_layer = nn.Linear(conv * 4, conv * 4)
        self.protein_attention_layer = nn.Linear(conv * 4, conv * 4)
        self.drug_attention_layer = nn.Linear(conv * 4, conv * 4)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.leaky_relu = nn.LeakyReLU()
        # Fully connected head; pooled drug and protein features are concatenated (conv * 8).
        self.fc1 = nn.Linear(conv * 8, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        # self.out = nn.Linear(512, 1)
    def forward(self, drug, protein):
        # (batch, seq_len) integer-encoded sequences -> (batch, seq_len, char_dim)
        drugembed = self.drug_embed(drug.long())
        proteinembed = self.protein_embed(protein.long())
        # Conv1d expects (batch, channels, seq_len)
        drugembed = drugembed.permute(0, 2, 1)
        proteinembed = proteinembed.permute(0, 2, 1)
        drug_conv = self.drug_cnn(drugembed)           # (batch, conv * 4, drug_len')
        protein_conv = self.protein_cnn(proteinembed)  # (batch, conv * 4, protein_len')
        # Project both feature maps before building the pairwise interaction map.
        drug_att = self.drug_attention_layer(drug_conv.permute(0, 2, 1))
        protein_att = self.protein_attention_layer(protein_conv.permute(0, 2, 1))
        d_att_layers = torch.unsqueeze(drug_att, 2).repeat(1, 1, protein_conv.shape[-1], 1)  # repeat along protein length
        p_att_layers = torch.unsqueeze(protein_att, 1).repeat(1, drug_conv.shape[-1], 1, 1)  # repeat along drug length
        atten_matrix = self.attention_layer(self.relu(d_att_layers + p_att_layers))
        # Average the interaction map over the other sequence to get per-position weights.
        compound_atte = torch.mean(atten_matrix, 2)
        protein_atte = torch.mean(atten_matrix, 1)
        compound_atte = self.sigmoid(compound_atte.permute(0, 2, 1))
        protein_atte = self.sigmoid(protein_atte.permute(0, 2, 1))
        # Residual-style reweighting of the convolutional features.
        drug_conv = drug_conv * 0.5 + drug_conv * compound_atte
        protein_conv = protein_conv * 0.5 + protein_conv * protein_atte
        # Global max pooling collapses the remaining sequence dimension.
        drug_conv = self.drug_max_pool(drug_conv).squeeze(2)
        protein_conv = self.protein_max_pool(protein_conv).squeeze(2)
        preds = torch.cat([drug_conv, protein_conv], dim=1)
        preds = self.dropout1(preds)
        preds = self.leaky_relu(self.fc1(preds))
        preds = self.dropout2(preds)
        preds = self.leaky_relu(self.fc2(preds))
        preds = self.dropout3(preds)
        preds = self.leaky_relu(self.fc3(preds))
        # preds = self.out(preds)
        return preds
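

# Minimal smoke-test sketch, not part of the original model code. It assumes
# integer-encoded inputs padded to the default drug_max_len=100 and
# protein_max_len=1000; the random token values and batch size below are
# illustrative only, not taken from any real dataset or vocabulary.
if __name__ == "__main__":
    model = HyperAttentionDTI()
    drug = torch.randint(1, 63, (2, 100))      # (batch, drug_max_len)
    protein = torch.randint(1, 26, (2, 1000))  # (batch, protein_max_len)
    features = model(drug, protein)
    print(features.shape)  # expected: torch.Size([2, 512])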