# NOTE(review): scraped Hugging Face Spaces page chrome ("Spaces: Running on
# CPU Upgrade") removed — it was not part of the program.
import torch
import torch.nn as nn
class DeepConvDTI(nn.Module):
    """Drug–target interaction model with two input branches.

    A drug feature vector and a protein input are encoded independently:

    * **Drug branch** — an MLP (``LazyLinear -> BatchNorm1d -> activation ->
      Dropout`` per layer) over the drug feature vector.
    * **Protein branch** — when ``convolution`` is True, the integer-encoded
      protein sequence is embedded and passed through parallel 1-D
      convolutions with different window sizes, each globally max-pooled to
      ``n_filters`` values and concatenated; otherwise the raw protein input
      is used directly. Either way, an optional MLP (``protein_layers``)
      follows.

    The two branch outputs are concatenated and passed through the optional
    fully connected head (``fc_layers``). When ``fc_layers`` is falsy the head
    is an empty ``nn.Sequential`` and acts as the identity.

    NOTE(review): there is no final 1-unit/sigmoid prediction layer here —
    the module returns the (possibly transformed) concatenated features.
    Presumably a caller adds the prediction head or loss; confirm upstream.

    Parameters
    ----------
    dropout : float
        Dropout probability inside the dense blocks.
    drug_layers : sequence of int
        Hidden sizes of the drug MLP.
    protein_windows : sequence of int
        One parallel Conv1d (kernel size = window) per entry.
    n_filters : int
        Output channels of each protein convolution.
    decay : float
        Stored but never read by this module — presumably weight decay for an
        externally constructed optimizer; verify against the caller.
    fc_layers : sequence of int or None
        Hidden sizes of the combined head; None/empty means identity.
    convolution : bool
        Select the convolutional protein encoder vs. raw protein features.
    activation : nn.Module or None
        Activation module used everywhere; defaults to ``nn.ReLU()``.
        (The original signature used ``activation=nn.ReLU()`` — a mutable
        default shared across instances; ``None`` preserves behaviour
        without the sharing.)
    protein_layers : sequence of int or None
        Hidden sizes of the protein MLP; None/empty skips it.
    """

    def __init__(self, dropout=0.2, drug_layers=(1024, 512), protein_windows=(10, 15, 20, 25), n_filters=64,
                 decay=0.0, fc_layers=None, convolution=True, activation=None, protein_layers=None):
        super().__init__()
        if activation is None:
            activation = nn.ReLU()
        self.dropout = dropout
        self.drug_layers = drug_layers
        self.protein_windows = protein_windows
        self.filters = n_filters
        self.decay = decay  # unused by the module itself; kept for callers
        self.fc_layers = fc_layers
        self.convolution = convolution
        self.activation = activation
        self.protein_layers = protein_layers

        # Drug branch: LazyLinear infers in_features on the first forward pass.
        self.drug_branch = self._dense_stack(drug_layers, activation, dropout)

        if convolution:
            # 26 embedding rows, 20-dim vectors — presumably amino-acid tokens
            # (+ padding); confirm the tokenizer's alphabet upstream.
            self.protein_embedding = nn.Embedding(26, 20)
            # Parallel convolutions with different receptive fields; the
            # AdaptiveMaxPool1d(1) collapses any sequence length to 1, so each
            # branch contributes exactly n_filters features.
            self.protein_convs = nn.ModuleList(
                nn.Sequential(
                    nn.Conv1d(20, n_filters, window, padding="same"),
                    nn.BatchNorm1d(n_filters),
                    activation,
                    nn.AdaptiveMaxPool1d(1),
                )
                for window in protein_windows
            )
        if protein_layers:
            self.protein_branch = self._dense_stack(protein_layers, activation, dropout)

        # Combined head; no Dropout here (matches the original), and an empty
        # Sequential is the identity when fc_layers is falsy.
        head = []
        if fc_layers:
            for size in fc_layers:
                head += [nn.LazyLinear(size), nn.BatchNorm1d(size), activation]
        self.final_branch = nn.Sequential(*head)

    @staticmethod
    def _dense_stack(sizes, activation, dropout):
        """Build a LazyLinear -> BatchNorm1d -> activation -> Dropout stack per size."""
        layers = []
        for size in sizes:
            layers += [nn.LazyLinear(size), nn.BatchNorm1d(size), activation, nn.Dropout(dropout)]
        return nn.Sequential(*layers)

    def forward(self, input_d, input_p):
        """Encode both inputs and return the combined head's output.

        Parameters
        ----------
        input_d : Tensor
            Drug features, shape (batch, drug_features); cast to float.
        input_p : Tensor
            With ``convolution=True``: integer token ids, shape (batch, seq_len),
            values in [0, 26). Otherwise: float protein features, (batch, p_features).

        Returns
        -------
        Tensor of shape (batch, fc_layers[-1]) when ``fc_layers`` is set,
        otherwise the concatenated branch features.
        """
        output_d = self.drug_branch(input_d.float())
        if self.convolution:
            # (batch, seq, 20) -> (batch, 20, seq) for Conv1d's channel-first layout.
            output_p = self.protein_embedding(input_p).transpose(1, 2)
            # Each conv yields (batch, n_filters, 1); squeeze and concatenate.
            output_p = torch.cat(
                [conv(output_p).squeeze(-1) for conv in self.protein_convs], dim=1
            )
        else:
            output_p = input_p
        if self.protein_layers:
            output_p = self.protein_branch(output_p)
        return self.final_branch(torch.cat([output_d, output_p], dim=1))