File size: 3,824 Bytes
c0ec7e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn


class DeepConvDTI(nn.Module):
    """Two-branch drug/protein feature network.

    The drug input goes through a stack of dense layers. The protein input is
    either embedded and passed through parallel 1-D convolutions with global
    max pooling (``convolution=True``) or used directly as a feature vector
    (``convolution=False``), optionally followed by its own dense stack. Both
    branch outputs are concatenated and passed through optional fully
    connected layers.

    This module has no final prediction layer: ``forward`` returns the feature
    tensor produced by the last configured layer.

    Args:
        dropout: Dropout probability used in the drug and protein dense stacks.
        drug_layers: Output sizes of the dense layers in the drug branch.
        protein_windows: Kernel sizes of the parallel protein convolutions.
        n_filters: Number of output channels of each protein convolution.
        decay: Stored on the instance as ``self.decay``; not used by this class.
        fc_layers: Output sizes of the dense layers applied after
            concatenation. ``None`` or empty means the concatenated features
            are returned unchanged.
        convolution: Whether to embed + convolve the protein input.
        activation: Activation module inserted after every batch norm. The same
            module instance is reused at every position. Defaults to a fresh
            ``nn.ReLU()`` per model when ``None``.
        protein_layers: Output sizes of the dense layers in the protein branch,
            or ``None``/empty for no dense layers.
    """

    def __init__(self, dropout: float = 0.2, drug_layers=(1024, 512), protein_windows=(10, 15, 20, 25),
                 n_filters: int = 64, decay: float = 0.0, fc_layers=None, convolution: bool = True,
                 activation=None, protein_layers=None):
        super().__init__()
        # Build the default activation per instance rather than sharing one
        # module object (created at definition time) between all models.
        if activation is None:
            activation = nn.ReLU()

        self.dropout = dropout
        self.drug_layers = drug_layers
        self.protein_windows = protein_windows
        self.filters = n_filters
        self.decay = decay
        self.fc_layers = fc_layers
        self.convolution = convolution
        self.activation = activation  # Use any nn.Module as the activation function
        self.protein_layers = protein_layers

        # Drug branch: Linear -> BatchNorm -> activation -> Dropout per layer.
        self.drug_branch = self._dense_stack(drug_layers, activation, dropout)

        # Protein branch (sequence path): embedding + parallel conv/pool heads.
        if convolution:
            # 26 token indices mapped to 20-dimensional vectors.
            self.protein_embedding = nn.Embedding(26, 20)
            # One conv head per window size; each pools to a single value per filter.
            self.protein_convs = nn.ModuleList(
                nn.Sequential(
                    nn.Conv1d(20, n_filters, window_size, padding="same"),
                    nn.BatchNorm1d(n_filters),
                    activation,
                    nn.AdaptiveMaxPool1d(1),
                )
                for window_size in protein_windows
            )

        # Optional dense layers on top of the protein features.
        if protein_layers:
            self.protein_branch = self._dense_stack(protein_layers, activation, dropout)

        # Final branch over the concatenated features (no dropout here).
        # Always a Sequential: with no fc_layers it is empty and passes its
        # input through unchanged, so forward() can call it unconditionally.
        self.final_branch = self._dense_stack(fc_layers or (), activation)

    @staticmethod
    def _dense_stack(layer_sizes, activation, dropout=None):
        """Build Linear -> BatchNorm -> activation (-> Dropout) for each size."""
        layers = []
        for layer_size in layer_sizes:
            layers += [nn.LazyLinear(layer_size), nn.BatchNorm1d(layer_size), activation]
            if dropout is not None:
                layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)

    def forward(self, input_d: torch.Tensor, input_p: torch.Tensor) -> torch.Tensor:
        """Compute the combined drug/protein features.

        Args:
            input_d: Drug features, 2-D ``(batch, n_features)``; cast to float.
            input_p: With ``convolution=True``, integer token indices of shape
                ``(batch, seq_len)`` with values in ``[0, 26)``. Otherwise a
                2-D ``(batch, n_features)`` tensor, cast to float.

        Returns:
            A ``(batch, width)`` tensor, where ``width`` is ``fc_layers[-1]``
            when ``fc_layers`` is given, else the concatenated branch width.
        """
        output_d = self.drug_branch(input_d.float())

        if self.convolution:
            # (batch, seq, 20) -> (batch, 20, seq): Conv1d expects channels first.
            embedded = self.protein_embedding(input_p).transpose(1, 2)
            # Each head yields (batch, n_filters, 1); drop the pooled axis.
            pooled = [conv(embedded).squeeze(-1) for conv in self.protein_convs]
            output_p = torch.cat(pooled, dim=1)
        else:
            # Match the drug branch: dense layers need floating-point input.
            output_p = input_p.float()

        if self.protein_layers:
            output_p = self.protein_branch(output_p)

        output_t = torch.cat([output_d, output_p], dim=1)
        return self.final_branch(output_t)