import torch
import torch.nn as nn


# TODO: this is an easy model; refactor it to be customized by config file only
class DeepDTA(nn.Module):
    """Drug-target model adapted from DeepDTA.

    Drug and protein token sequences are each embedded, passed through their
    own caller-supplied encoder, concatenated along dim 1, and projected by a
    two-layer MLP (Linear -> ReLU -> Dropout, twice).

    Args:
        drug_cnn: Encoder applied to the embedded drug sequence. Assumed to
            return a tensor whose dim 1 is the feature dim (e.g.
            ``(batch, features)``) so it can be concatenated with the protein
            encoding. TODO: confirm against the encoders used in practice.
        protein_cnn: Encoder applied to the embedded protein sequence; same
            output assumption as ``drug_cnn``.
        num_features_drug: Number of rows in the drug embedding table; drug
            token ids must lie in ``[0, num_features_drug)``.
        num_features_protein: Number of rows in the protein embedding table;
            protein token ids must lie in ``[0, num_features_protein)``.
        embed_dim: Embedding width used for both drug and protein tokens.
        hidden_dim: Width of the first MLP layer (default 1024).
        out_dim: Width of the second MLP layer, i.e. the model's output
            feature size (default 512).
        dropout: Dropout probability applied after each MLP layer
            (default 0.1).

    Note:
        The first MLP layer is an ``nn.LazyLinear`` whose input size is
        inferred on the first forward call. Run one forward pass (a dry run)
        before handing ``parameters()`` to an optimizer.
    """

    def __init__(
        self,
        drug_cnn: nn.Module,
        protein_cnn: nn.Module,
        num_features_drug: int,
        num_features_protein: int,
        embed_dim: int,
        *,
        hidden_dim: int = 1024,
        out_dim: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Per-modality encoders supplied by the caller.
        self.drug_cnn = drug_cnn
        self.protein_cnn = protein_cnn
        # Layer order is kept fixed so state_dict keys (fc.0.*, fc.3.*) stay
        # compatible with checkpoints saved by earlier versions of this class.
        self.fc = nn.Sequential(
            nn.LazyLinear(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        # Token-id -> vector lookup tables for each sequence type; these feed
        # the encoders above.
        self.drug_embedding = nn.Embedding(num_features_drug, embed_dim)
        self.protein_embedding = nn.Embedding(num_features_protein, embed_dim)

    def forward(self, v_d: torch.Tensor, v_p: torch.Tensor) -> torch.Tensor:
        """Encode a drug/protein pair and return the fused MLP features.

        Args:
            v_d: Drug token ids; cast to ``long`` before the embedding lookup,
                so integer-valued float tensors are also accepted.
            v_p: Protein token ids; cast to ``long`` likewise.

        Returns:
            The output of the MLP head, with last dimension ``out_dim``.
        """
        # nn.Embedding requires integer indices.
        v_d = self.drug_cnn(self.drug_embedding(v_d.long()))
        v_p = self.protein_cnn(self.protein_embedding(v_p.long()))
        # Fuse the two encodings along the feature dimension.
        v_f = torch.cat([v_d, v_p], 1)
        return self.fc(v_f)