import torch
import torch.nn as nn

# TODO: this is a simple model; refactor it so it can be configured entirely from a config file


class DeepDTA(nn.Module):
    """
    From DeepDTA
    """
    def __init__(
            self,
            drug_cnn: nn.Module,
            protein_cnn: nn.Module,
            num_features_drug: int,
            num_features_protein: int,
            embed_dim: int,
    ):
        super().__init__()
        self.drug_cnn = drug_cnn
        self.protein_cnn = protein_cnn
        # fully connected head that fuses the concatenated drug/protein features
        self.fc = nn.Sequential(
            nn.LazyLinear(1024), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.1),
        )

        # token embeddings for the drug and protein sequences
        self.drug_embedding = nn.Embedding(num_features_drug, embed_dim)
        self.protein_embedding = nn.Embedding(num_features_protein, embed_dim)
    
    def forward(self, v_d, v_p):
        # embed drug token indices, then encode with the drug CNN
        v_d = self.drug_embedding(v_d.long())
        v_d = self.drug_cnn(v_d)

        # embed protein token indices, then encode with the protein CNN
        v_p = self.protein_embedding(v_p.long())
        v_p = self.protein_cnn(v_p)

        # concatenate both representations and pass through the FC head
        v_f = torch.cat([v_d, v_p], dim=1)
        v_f = self.fc(v_f)

        return v_f
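

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The CNN encoder
    # below is an assumption: any nn.Module that maps (batch, seq_len,
    # embed_dim) embeddings to a fixed-size vector per sample will work,
    # since DeepDTA takes the two encoders by injection.
    class SimpleCNN(nn.Module):
        def __init__(self, embed_dim: int, out_channels: int = 32):
            super().__init__()
            self.conv = nn.Conv1d(embed_dim, out_channels, kernel_size=4)

        def forward(self, x):
            # (batch, seq_len, embed_dim) -> (batch, embed_dim, seq_len) for Conv1d
            x = self.conv(x.permute(0, 2, 1))
            # global max-pool over the sequence dimension -> (batch, out_channels)
            return torch.relu(x).max(dim=2).values

    # vocabulary sizes and embedding width below are illustrative, not canonical
    model = DeepDTA(
        drug_cnn=SimpleCNN(embed_dim=128),
        protein_cnn=SimpleCNN(embed_dim=128),
        num_features_drug=64,
        num_features_protein=26,
        embed_dim=128,
    )
    drugs = torch.randint(0, 64, (8, 100))      # batch of tokenised SMILES
    proteins = torch.randint(0, 26, (8, 1000))  # batch of tokenised proteins
    print(model(drugs, proteins).shape)         # torch.Size([8, 512])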