Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
import torch.nn as nn | |
# TODO: this is a simple model; refactor it so it can be configured entirely from a config file.
class DeepDTA(nn.Module):
    """Drug-target model adapted from DeepDTA.

    Drug and protein inputs are integer token sequences. Each one is embedded
    with its own ``nn.Embedding`` and encoded by a caller-supplied module. The
    two encodings are concatenated along dim 1 and passed through an MLP head
    made of ``Linear -> ReLU -> Dropout`` stages. The first stage is an
    ``nn.LazyLinear``, so the concatenated width does not have to be known
    when the model is constructed.

    The head ends with ReLU/Dropout at ``hidden_dims[-1]`` features. This
    class does not add a final output/regression layer.

    Args:
        drug_cnn: Encoder applied to the drug embeddings.
        protein_cnn: Encoder applied to the protein embeddings.
        num_features_drug: Drug vocabulary size (number of embedding rows).
        num_features_protein: Protein vocabulary size (number of embedding rows).
        embed_dim: Size of each embedding vector.
        hidden_dims: Output width of each head stage, in order. The default
            ``(1024, 512)`` reproduces the original head. An empty sequence
            makes the head an identity.
        dropout: Dropout probability applied after every stage's ReLU.
    """

    def __init__(
        self,
        drug_cnn: nn.Module,
        protein_cnn: nn.Module,
        num_features_drug: int,
        num_features_protein: int,
        embed_dim: int,
        *,
        hidden_dims: tuple = (1024, 512),
        dropout: float = 0.1,
    ):
        super().__init__()
        self.drug_cnn = drug_cnn
        self.protein_cnn = protein_cnn

        layers = []
        in_dim = None
        for out_dim in tuple(hidden_dims):
            # The first layer is lazy: its input width is the concatenated
            # encoder width, which is only known after the first forward pass.
            if in_dim is None:
                linear = nn.LazyLinear(out_dim)
            else:
                linear = nn.Linear(in_dim, out_dim)
            layers += [linear, nn.ReLU(), nn.Dropout(dropout)]
            in_dim = out_dim
        # Three modules per stage, so with the defaults the state_dict keys
        # stay fc.0.* and fc.3.*, matching the previously hard-coded head.
        self.fc = nn.Sequential(*layers)

        # Token embeddings. Their outputs feed drug_cnn / protein_cnn.
        self.drug_embedding = nn.Embedding(num_features_drug, embed_dim)
        self.protein_embedding = nn.Embedding(num_features_protein, embed_dim)

    def forward(self, v_d: torch.Tensor, v_p: torch.Tensor) -> torch.Tensor:
        """Encode the drug and protein token sequences and apply the head.

        Args:
            v_d: Drug token indices, cast to ``long`` before embedding.
                Presumably shaped (batch, drug_len); TODO confirm with callers.
            v_p: Protein token indices, cast to ``long`` before embedding.
                Presumably shaped (batch, protein_len); TODO confirm with callers.

        Returns:
            Output of ``self.fc`` on the dim-1 concatenation of the drug and
            protein encodings.
        """
        # nn.Embedding appends an embed_dim axis, and the encoders receive
        # that tensor unchanged (no transpose here).
        # NOTE(review): a Conv1d-based encoder must move embed_dim onto the
        # channel axis itself; confirm against the encoders actually used.
        drug = self.drug_cnn(self.drug_embedding(v_d.long()))
        protein = self.protein_cnn(self.protein_embedding(v_p.long()))
        # Assumes both encoders return (batch, features), so dim 1 is the
        # feature axis being joined.
        return self.fc(torch.cat([drug, protein], 1))