import torch
import torch.nn as nn


# TODO: this is a simple model; refactor it to be customized by config file only
class HyperAttentionDTI(nn.Module):
    def __init__(
        self,
        protein_kernel=(4, 8, 12),
        drug_kernel=(4, 6, 8),
        conv=40,
        char_dim=64,
        protein_max_len=1000,
        drug_max_len=100,
    ):
        super().__init__()

        # Drug branch: character-level embedding followed by a 3-layer 1D CNN.
        self.drug_embed = nn.Embedding(63, char_dim, padding_idx=0)
        self.drug_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=drug_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=drug_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=drug_kernel[2]),
            nn.ReLU(),
        )
        # Each valid (unpadded) conv shortens the sequence by kernel_size - 1, so
        # this pooling window spans exactly the remaining sequence length.
        self.drug_max_pool = nn.MaxPool1d(
            drug_max_len - drug_kernel[0] - drug_kernel[1] - drug_kernel[2] + 3
        )

        # Protein branch: same structure with its own vocabulary and kernel sizes.
        self.protein_embed = nn.Embedding(26, char_dim, padding_idx=0)
        self.protein_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=protein_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=protein_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=protein_kernel[2]),
            nn.ReLU(),
        )
        self.protein_max_pool = nn.MaxPool1d(
            protein_max_len - protein_kernel[0] - protein_kernel[1] - protein_kernel[2] + 3
        )

        # Pairwise attention between drug and protein feature maps.
        self.attention_layer = nn.Linear(conv * 4, conv * 4)
        self.protein_attention_layer = nn.Linear(conv * 4, conv * 4)
        self.drug_attention_layer = nn.Linear(conv * 4, conv * 4)

        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.leaky_relu = nn.LeakyReLU()

        # Prediction head; with the final 1-unit layer disabled, the model
        # returns 512-dim joint features rather than a scalar score.
        self.fc1 = nn.Linear(conv * 8, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        # self.out = nn.Linear(512, 1)

    def forward(self, drug, protein):
        # (B, L) token ids -> (B, L, char_dim) -> (B, char_dim, L) for Conv1d.
        drugembed = self.drug_embed(drug.long()).permute(0, 2, 1)
        proteinembed = self.protein_embed(protein.long()).permute(0, 2, 1)
        drug_conv = self.drug_cnn(drugembed)            # (B, conv*4, L_d)
        protein_conv = self.protein_cnn(proteinembed)   # (B, conv*4, L_p)

        drug_att = self.drug_attention_layer(drug_conv.permute(0, 2, 1))           # (B, L_d, conv*4)
        protein_att = self.protein_attention_layer(protein_conv.permute(0, 2, 1))  # (B, L_p, conv*4)

        # Expand both to a (B, L_d, L_p, conv*4) pairwise grid.
        d_att_layers = torch.unsqueeze(drug_att, 2).repeat(1, 1, protein_conv.shape[-1], 1)  # repeat along protein size
        p_att_layers = torch.unsqueeze(protein_att, 1).repeat(1, drug_conv.shape[-1], 1, 1)  # repeat along drug size
        atten_matrix = self.attention_layer(self.relu(d_att_layers + p_att_layers))

        # Average the grid over the other sequence to get per-position gates.
        compound_atte = self.sigmoid(torch.mean(atten_matrix, 2).permute(0, 2, 1))  # (B, conv*4, L_d)
        protein_atte = self.sigmoid(torch.mean(atten_matrix, 1).permute(0, 2, 1))   # (B, conv*4, L_p)

        # Residual-style gating: keep half the raw signal, scale the rest by attention.
        drug_conv = drug_conv * 0.5 + drug_conv * compound_atte
        protein_conv = protein_conv * 0.5 + protein_conv * protein_atte

        # Global max pool over the sequence dimension, then fuse the branches.
        drug_conv = self.drug_max_pool(drug_conv).squeeze(2)           # (B, conv*4)
        protein_conv = self.protein_max_pool(protein_conv).squeeze(2)  # (B, conv*4)
        preds = torch.cat([drug_conv, protein_conv], dim=1)            # (B, conv*8)

        preds = self.dropout1(preds)
        preds = self.leaky_relu(self.fc1(preds))
        preds = self.dropout2(preds)
        preds = self.leaky_relu(self.fc2(preds))
        preds = self.dropout3(preds)
        preds = self.leaky_relu(self.fc3(preds))
        # preds = self.out(preds)
        return preds
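

# A minimal smoke test sketch, assuming drug/protein inputs are integer token
# ids right-padded with 0 to drug_max_len=100 and protein_max_len=1000. The
# random ids below are placeholders, not a real tokenization scheme.
if __name__ == "__main__":
    model = HyperAttentionDTI()
    drug = torch.randint(1, 63, (2, 100))      # batch of 2 drug sequences
    protein = torch.randint(1, 26, (2, 1000))  # batch of 2 protein sequences
    features = model(drug, protein)
    print(features.shape)  # torch.Size([2, 512]) while self.out stays disabled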