# NOTE: residue from the web file viewer this module was copied from (not code):
# libokj's picture | Upload 110 files | c0ec7e6 | raw | history blame | 3.82 kB
import torch
import torch.nn as nn
class DeepConvDTI(nn.Module):
    """Two-branch drug/protein feature encoder.

    The drug input passes through a stack of
    ``LazyLinear -> BatchNorm1d -> activation -> Dropout`` blocks. The protein
    input is either embedded and fed to parallel ``Conv1d`` layers with global
    max pooling (``convolution=True``) or used directly as a feature tensor
    (``convolution=False``); in both cases it may then pass through its own
    dense stack. The two representations are concatenated along dim 1 and
    passed through the ``fc_layers`` stack (an identity when ``fc_layers`` is
    falsy).

    Args:
        dropout: Dropout probability used in the drug and protein dense stacks.
        drug_layers: Output sizes of the drug dense layers.
        protein_windows: Kernel sizes of the parallel protein convolutions.
        n_filters: Number of output channels of each protein convolution.
        decay: Only stored on the instance; this module never reads it.
        fc_layers: Output sizes of the dense layers applied after the
            concatenation. No layers are added when falsy.
        convolution: Whether to embed and convolve the protein input.
        activation: Module used as the non-linearity. The very same instance
            is inserted at every activation position of every branch.
        protein_layers: Output sizes of the dense layers applied to the
            protein representation. No layers are added when falsy.
    """

    def __init__(self, dropout=0.2, drug_layers=(1024, 512), protein_windows=(10, 15, 20, 25), n_filters=64,
                 decay=0.0, fc_layers=None, convolution=True, activation=nn.ReLU(), protein_layers=None):
        super().__init__()
        self.dropout = dropout
        self.drug_layers = drug_layers
        self.protein_windows = protein_windows
        self.filters = n_filters
        self.decay = decay
        self.fc_layers = fc_layers
        self.convolution = convolution
        self.activation = activation  # Use any nn.Module as the activation function
        self.protein_layers = protein_layers

        # Drug branch: dense blocks with dropout.
        self.drug_branch = self._dense_stack(drug_layers, activation, dropout)

        # Protein branch, convolutional part.
        if convolution:
            # 26 embedding indices, each mapped to a 20-dimensional vector.
            self.protein_embedding = nn.Embedding(26, 20)
            # One conv -> batch norm -> activation -> global max pool pipeline
            # per window size; all pipelines see the same embedded sequence.
            self.protein_convs = nn.ModuleList(
                nn.Sequential(
                    nn.Conv1d(20, n_filters, window_size, padding="same"),
                    nn.BatchNorm1d(n_filters),
                    activation,
                    nn.AdaptiveMaxPool1d(1),
                )
                for window_size in protein_windows
            )

        # Protein branch, optional dense part (attribute only exists when
        # requested, exactly as forward() expects).
        if protein_layers:
            self.protein_branch = self._dense_stack(protein_layers, activation, dropout)

        # Final branch over the concatenated features; no dropout here.
        # An empty Sequential (no fc_layers) returns its input unchanged.
        self.final_branch = self._dense_stack(fc_layers, activation)

    @staticmethod
    def _dense_stack(layer_sizes, activation, dropout=None):
        """Build LazyLinear -> BatchNorm1d -> activation [-> Dropout] blocks.

        Args:
            layer_sizes: Output size of each block; falsy means no blocks.
            activation: Module appended (same instance) after every batch norm.
            dropout: Dropout probability, or ``None`` to omit the dropout layer.

        Returns:
            An ``nn.Sequential`` holding the blocks in order.
        """
        modules = []
        for layer_size in layer_sizes or ():
            modules += [nn.LazyLinear(layer_size), nn.BatchNorm1d(layer_size), activation]
            if dropout is not None:
                modules.append(nn.Dropout(dropout))
        return nn.Sequential(*modules)

    def forward(self, input_d, input_p):
        """Encode a batch of drug/protein pairs into one feature tensor.

        Args:
            input_d: Drug features; cast to float before the dense layers.
                Presumably shaped (batch, n_features) - confirm with caller.
            input_p: With ``convolution=True``, a (batch, seq_len) tensor of
                embedding indices. Otherwise a protein feature tensor that is
                cast to float and, if configured, fed to the protein dense
                layers.

        Returns:
            The output of the final dense stack, or the plain concatenation
            of both branch outputs along dim 1 when ``fc_layers`` is falsy.
        """
        drug_repr = self.drug_branch(input_d.float())

        if self.convolution:
            # (batch, seq_len) -> (batch, seq_len, 20) -> (batch, 20, seq_len)
            # so that the embedding dimension becomes the Conv1d channel axis.
            embedded = self.protein_embedding(input_p).transpose(1, 2)
            # Each pipeline ends in a length-1 pooled axis; drop it and join
            # the per-window filter responses side by side.
            protein_repr = torch.cat(
                [conv(embedded).squeeze(-1) for conv in self.protein_convs], dim=1
            )
        else:
            # Cast like the drug input: without it, integer or double feature
            # tensors hit a dtype mismatch in the float dense layers below.
            protein_repr = input_p.float()

        if self.protein_layers:
            protein_repr = self.protein_branch(protein_repr)

        joint = torch.cat([drug_repr, protein_repr], dim=1)
        return self.final_branch(joint)