import torch
import torch.nn as nn


class DeepConvDTI(nn.Module):
    """Drug-target interaction network: dense drug branch + CNN protein branch.

    The model has three parts:

    * drug branch: for each width in ``drug_layers`` a
      ``LazyLinear -> BatchNorm1d -> activation -> Dropout`` stack;
    * protein branch: when ``convolution`` is true, protein token ids are
      embedded (vocabulary 26, dimension 20) and fed through one
      ``Conv1d -> BatchNorm1d -> activation -> global max pool`` per kernel
      size in ``protein_windows``; the pooled features are concatenated.
      Otherwise ``input_p`` is used directly as a feature tensor. In both
      modes an optional dense stack (``protein_layers``) follows;
    * final branch: drug and protein features are concatenated along dim 1
      and passed through ``LazyLinear -> BatchNorm1d -> activation`` per width
      in ``fc_layers`` (no dropout in this branch).

    All linear layers are ``nn.LazyLinear``: their input widths are only fixed
    on the first forward pass, so run one dummy batch before constructing the
    optimizer.

    NOTE(review): there is no 1-unit output head; ``forward`` returns the
    final-branch features. Presumably the caller adds a classifier — confirm.

    Args:
        dropout: dropout probability for the drug and protein dense stacks
            (``None`` omits the Dropout layers).
        drug_layers: output widths of the drug dense layers.
        protein_windows: Conv1d kernel sizes, one parallel conv per entry.
        n_filters: output channels of every protein convolution.
        decay: stored as ``self.decay`` but not used by this module
            (presumably a weight-decay coefficient for the optimizer — confirm).
        fc_layers: widths of the combined dense layers, or ``None`` for none.
        convolution: use the embedding + CNN protein path.
        activation: activation module; ``None`` (default) creates a fresh
            ``nn.ReLU()`` for this model. The same instance is inserted into
            every layer, so a parameterised activation (e.g. ``nn.PReLU``)
            shares its parameters across all layers.
        protein_layers: widths of the protein dense layers, or ``None``.
    """

    def __init__(self, dropout=0.2, drug_layers=(1024, 512),
                 protein_windows=(10, 15, 20, 25), n_filters=64, decay=0.0,
                 fc_layers=None, convolution=True, activation=None,
                 protein_layers=None):
        super().__init__()
        if activation is None:
            # Built per instance: a module used as a default argument would be
            # one object shared (and registered) by every model constructed.
            activation = nn.ReLU()

        self.dropout = dropout
        self.drug_layers = drug_layers
        self.protein_windows = protein_windows
        self.filters = n_filters
        self.decay = decay
        self.fc_layers = fc_layers
        self.convolution = convolution
        self.activation = activation  # nn.Module, so registered as a submodule
        self.protein_layers = protein_layers

        # Drug branch: dense stack with dropout.
        self.drug_branch = self._dense_stack(drug_layers, activation, dropout)

        if convolution:
            # Token ids in [0, 26) embedded into 20 channels.
            self.protein_embedding = nn.Embedding(26, 20)
            # Parallel conv towers, one per window size; each ends in a global
            # max pool over the sequence axis -> (batch, n_filters, 1).
            self.protein_convs = nn.ModuleList(
                nn.Sequential(
                    nn.Conv1d(20, n_filters, window_size, padding="same"),
                    nn.BatchNorm1d(n_filters),
                    activation,
                    nn.AdaptiveMaxPool1d(1),
                )
                for window_size in protein_windows
            )

        if protein_layers:
            # Optional dense stack applied to the protein features in both modes.
            self.protein_branch = self._dense_stack(
                protein_layers, activation, dropout
            )

        # Final branch: no dropout; an empty Sequential acts as identity.
        self.final_branch = self._dense_stack(fc_layers, activation)

    @staticmethod
    def _dense_stack(layer_sizes, activation, dropout=None):
        """Return a Sequential of LazyLinear/BatchNorm1d/activation[/Dropout] per width.

        ``layer_sizes`` may be ``None`` or empty, giving an identity Sequential.
        A Dropout layer follows each activation only when ``dropout`` is not
        ``None``.
        """
        modules = []
        for size in layer_sizes or ():
            modules += [nn.LazyLinear(size), nn.BatchNorm1d(size), activation]
            if dropout is not None:
                modules.append(nn.Dropout(dropout))
        return nn.Sequential(*modules)

    def forward(self, input_d, input_p):
        """Compute features for a batch of drug/protein pairs.

        Args:
            input_d: drug feature tensor, cast to float before use (assumed
                ``(batch, n_drug_features)`` — confirm against the caller).
            input_p: with ``convolution=True``, integer token ids of shape
                ``(batch, seq_len)`` with values in ``[0, 26)``; otherwise a
                feature tensor, cast to float like ``input_d``.

        Returns:
            A 2-D tensor: ``(batch, fc_layers[-1])`` when ``fc_layers`` is set,
            otherwise the drug and protein features concatenated along dim 1.
        """
        output_d = self.drug_branch(input_d.float())

        if self.convolution:
            # (batch, seq_len) -> (batch, seq_len, 20) -> (batch, 20, seq_len)
            # because Conv1d expects channels before the sequence axis.
            embedded = self.protein_embedding(input_p).transpose(1, 2)
            # squeeze(-1) drops the pooled length-1 axis -> (batch, n_filters).
            output_p = torch.cat(
                [conv(embedded).squeeze(-1) for conv in self.protein_convs],
                dim=1,
            )
        else:
            # Cast like the drug branch so integer features work with LazyLinear.
            output_p = input_p.float()

        if self.protein_layers:
            output_p = self.protein_branch(output_p)

        output_t = torch.cat([output_d, output_p], dim=1)
        return self.final_branch(output_t)