import torch
import torch.nn as nn


# TODO: this is a simple model; refactor it so it can be customized via a config file alone
class HyperAttentionDTI(nn.Module):
    def __init__(
        self,
        protein_kernel=(4, 8, 12),
        drug_kernel=(4, 6, 8),
        conv=40,
        char_dim=64,
        protein_max_len=1000,
        drug_max_len=100,
    ):
        super().__init__()
        # Character-level drug embedding; index 0 is the padding token.
        self.drug_embed = nn.Embedding(63, char_dim, padding_idx=0)
        # Three stacked valid (unpadded) 1D convolutions widen the channels
        # from char_dim to conv * 4 while shortening the sequence.
        self.drug_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=drug_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=drug_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=drug_kernel[2]),
            nn.ReLU(),
        )
        # Each conv layer shortens the sequence by (kernel_size - 1), so the
        # window below spans the full remaining length and pools it to 1.
        self.drug_max_pool = nn.MaxPool1d(
            drug_max_len - drug_kernel[0] - drug_kernel[1] - drug_kernel[2] + 3)
        # Character-level protein embedding (25 residue tokens plus padding at index 0).
        self.protein_embed = nn.Embedding(26, char_dim, padding_idx=0)
        self.protein_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=protein_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=protein_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=protein_kernel[2]),
            nn.ReLU(),
        )
        self.protein_max_pool = nn.MaxPool1d(
            protein_max_len - protein_kernel[0] - protein_kernel[1] - protein_kernel[2] + 3)
        # Linear projections used by the pairwise attention block.
        self.attention_layer = nn.Linear(conv * 4, conv * 4)
        self.protein_attention_layer = nn.Linear(conv * 4, conv * 4)
        self.drug_attention_layer = nn.Linear(conv * 4, conv * 4)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()  # not used in forward()
        self.sigmoid = nn.Sigmoid()
        self.leaky_relu = nn.LeakyReLU()
        # Fully connected head; its input is the concatenation of the pooled
        # drug and protein features (conv * 4 each, hence conv * 8).
        self.fc1 = nn.Linear(conv * 8, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        # self.out = nn.Linear(512, 1)

    def forward(self, drug, protein):
        # drug: (B, drug_max_len) and protein: (B, protein_max_len), integer-encoded sequences.
        drug_emb = self.drug_embed(drug.long())           # (B, L_d, char_dim)
        protein_emb = self.protein_embed(protein.long())  # (B, L_p, char_dim)
        # Conv1d expects channels first: (B, char_dim, L).
        drug_emb = drug_emb.permute(0, 2, 1)
        protein_emb = protein_emb.permute(0, 2, 1)
        drug_conv = self.drug_cnn(drug_emb)               # (B, conv * 4, L_d')
        protein_conv = self.protein_cnn(protein_emb)      # (B, conv * 4, L_p')
        # Project each sequence back to (B, L, conv * 4) for the attention layers.
        drug_att = self.drug_attention_layer(drug_conv.permute(0, 2, 1))
        protein_att = self.protein_attention_layer(protein_conv.permute(0, 2, 1))
        # Broadcast both projections to a pairwise grid of shape (B, L_d', L_p', conv * 4).
        d_att_layers = torch.unsqueeze(drug_att, 2).repeat(1, 1, protein_conv.shape[-1], 1)  # repeat along protein length
        p_att_layers = torch.unsqueeze(protein_att, 1).repeat(1, drug_conv.shape[-1], 1, 1)  # repeat along drug length
        atten_matrix = self.attention_layer(self.relu(d_att_layers + p_att_layers))
        # Average the pairwise attention over the other sequence's positions.
        compound_atte = torch.mean(atten_matrix, 2)       # (B, L_d', conv * 4)
        protein_atte = torch.mean(atten_matrix, 1)        # (B, L_p', conv * 4)
        compound_atte = self.sigmoid(compound_atte.permute(0, 2, 1))
        protein_atte = self.sigmoid(protein_atte.permute(0, 2, 1))
        # Gate each channel by (0.5 + attention), i.e. between 0.5x and 1.5x of the raw signal.
        drug_conv = drug_conv * 0.5 + drug_conv * compound_atte
        protein_conv = protein_conv * 0.5 + protein_conv * protein_atte
        # Global max pooling over the sequence dimension: (B, conv * 4, 1) -> (B, conv * 4).
        drug_conv = self.drug_max_pool(drug_conv).squeeze(2)
        protein_conv = self.protein_max_pool(protein_conv).squeeze(2)
        # Concatenated drug/protein features pass through a three-layer MLP head.
        preds = torch.cat([drug_conv, protein_conv], dim=1)  # (B, conv * 8)
        preds = self.dropout1(preds)
        preds = self.leaky_relu(self.fc1(preds))
        preds = self.dropout2(preds)
        preds = self.leaky_relu(self.fc2(preds))
        preds = self.dropout3(preds)
        preds = self.leaky_relu(self.fc3(preds))
        # preds = self.out(preds)
        return preds
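

if __name__ == "__main__":
    # Minimal smoke test, not part of the original model code: instantiate the
    # model with its default hyperparameters and run one forward pass on random
    # integer-encoded sequences. The batch size of 1 and the use of the full
    # max lengths are illustrative assumptions.
    model = HyperAttentionDTI()
    drug = torch.randint(0, 63, (1, 100))      # (B, drug_max_len), drug token ids
    protein = torch.randint(0, 26, (1, 1000))  # (B, protein_max_len), residue token ids
    with torch.no_grad():
        features = model(drug, protein)
    # The head outputs 512-dimensional pair features; uncomment self.out in
    # __init__ for a single-logit prediction per pair.
    print(features.shape)  # expected: torch.Size([1, 512])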