Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 1,141 Bytes
c0ec7e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from typing import Sequence

import torch
import torch.nn as nn
# TODO: this is a simple, largely hard-coded model; refactor it so that it can be configured entirely from a config file.
class DeepDTA(nn.Module):
    """Two-branch drug/protein model adapted from DeepDTA.

    Integer-encoded drug and protein inputs are each looked up in their own
    embedding table and then passed through their own injected encoder
    (``drug_cnn`` / ``protein_cnn``). The two encodings are concatenated
    along dim 1 and fed through a fully connected stack in which every
    hidden layer is ``Linear -> ReLU -> Dropout``.

    ``forward`` returns the output of the last hidden layer
    (``hidden_dims[-1]`` features, 512 by default). No final prediction
    layer is applied here.

    Args:
        drug_cnn: Encoder applied to the embedded drug input.
        protein_cnn: Encoder applied to the embedded protein input.
        num_features_drug: Number of rows in the drug embedding table.
        num_features_protein: Number of rows in the protein embedding table.
        embed_dim: Embedding size used by both embedding tables.
        hidden_dims: Output sizes of the fully connected layers. The first
            layer is an ``nn.LazyLinear``, so its input size is inferred on
            the first forward pass. An empty sequence makes the stack an
            identity. Defaults to ``(1024, 512)``.
        dropout: Dropout probability applied after every hidden layer.
            Defaults to ``0.1``.
    """

    def __init__(
        self,
        drug_cnn: nn.Module,
        protein_cnn: nn.Module,
        num_features_drug: int,
        num_features_protein: int,
        embed_dim: int,
        hidden_dims: Sequence[int] = (1024, 512),
        dropout: float = 0.1,
    ):
        super().__init__()
        self.drug_cnn = drug_cnn
        self.protein_cnn = protein_cnn
        self.fc = self._build_fc(hidden_dims, dropout)
        # Token embedding tables; their outputs are what the encoders consume.
        self.drug_embedding = nn.Embedding(num_features_drug, embed_dim)
        self.protein_embedding = nn.Embedding(num_features_protein, embed_dim)

    @staticmethod
    def _build_fc(hidden_dims: Sequence[int], dropout: float) -> nn.Sequential:
        """Build the ``Linear -> ReLU -> Dropout`` stack for ``hidden_dims``."""
        layers = []
        in_dim = None
        for out_dim in hidden_dims:
            # The first layer is lazy because the concatenated encoder width
            # depends on the injected encoders and is unknown until forward.
            if in_dim is None:
                linear = nn.LazyLinear(out_dim)
            else:
                linear = nn.Linear(in_dim, out_dim)
            layers += [linear, nn.ReLU(), nn.Dropout(dropout)]
            in_dim = out_dim
        return nn.Sequential(*layers)

    def forward(self, v_d, v_p):
        """Encode a drug/protein pair and return the fused feature tensor.

        Args:
            v_d: Drug token indices, cast to ``long`` before embedding.
                Presumably shaped (batch, drug_seq_len) - TODO confirm.
            v_p: Protein token indices, cast to ``long`` before embedding.
                Presumably shaped (batch, protein_seq_len) - TODO confirm.

        Returns:
            The fully connected stack applied to the dim-1 concatenation of
            the drug and protein encodings.
        """
        # NOTE(review): nn.Embedding appends embed_dim as the last axis; this
        # assumes drug_cnn / protein_cnn perform any channel-first permute
        # they need themselves - confirm against the encoder implementations.
        v_d = self.drug_cnn(self.drug_embedding(v_d.long()))
        v_p = self.protein_cnn(self.protein_embedding(v_p.long()))
        # Concatenation along dim 1 assumes both encoders return (batch, features).
        v_f = torch.cat([v_d, v_p], 1)
        return self.fc(v_f)
|