"""GraphDTA: drug-target binding affinity prediction (Lightning module)."""
import torch
import torch.nn as nn
from lightning import LightningModule
class GraphDTA(LightningModule):
    """Drug-target affinity model from GraphDTA (Nguyen et al., 2020;
    https://doi.org/10.1093/bioinformatics/btaa921).

    A GNN encodes the drug graph while a 1D-conv branch encodes the
    protein sequence; the two embeddings are concatenated and pushed
    through dense layers to a 512-d fused feature vector.

    Args:
        gnn: graph encoder for the drug. NOTE(review): its output must be
            (batch, d) with d + output_dim == 256, the hard-coded fc1
            input size below — confirm against the configured encoder.
        num_features_protein: vocabulary size of the protein token embedding.
        n_filters: number of output channels of the protein 1D conv.
        embed_dim: protein token embedding size. The conv (kernel_size=8)
            slides over this axis, so embed_dim must be 128 for the
            flatten width of 121 (= 128 - 8 + 1) to hold.
        output_dim: protein-branch output size.
        dropout: dropout probability applied after each dense layer.
    """

    def __init__(
        self,
        gnn: nn.Module,
        num_features_protein: int,
        n_filters: int,
        embed_dim: int,
        output_dim: int,
        dropout: float,
    ):
        super().__init__()
        self.gnn = gnn
        # Kept so the flatten in conv_forward_xt tracks the conv width.
        self.n_filters = n_filters
        # protein sequence encoder (1d conv)
        self.embedding_xt = nn.Embedding(num_features_protein, embed_dim)
        # LazyConv1d infers in_channels (here: the sequence length, see
        # conv_forward_xt) on the first forward call.
        self.conv_xt = nn.LazyConv1d(out_channels=n_filters, kernel_size=8)
        # BUG FIX: the flatten size was hard-coded to 32 channels, silently
        # ignoring n_filters (any n_filters != 32 crashed at runtime);
        # unchanged for the previously-working value n_filters == 32.
        self.fc1_xt = nn.Linear(n_filters * 121, output_dim)
        # combined layers (256 = gnn output dim + output_dim, see class doc)
        self.fc1 = nn.Linear(256, 1024)
        self.fc2 = nn.Linear(1024, 512)
        # activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def conv_forward_xt(self, v_p):
        """Encode an integer-token protein sequence to (batch, output_dim)."""
        # (batch, seq) -> (batch, seq, embed_dim); cast in case tokens arrive as float
        v_p = self.embedding_xt(v_p.long())
        # NOTE: as in the original GraphDTA code, Conv1d treats the sequence
        # axis as channels and convolves over the embedding axis, so the
        # output length is embed_dim - 8 + 1.
        v_p = self.conv_xt(v_p)  # -> (batch, n_filters, embed_dim - 7)
        # flatten; requires embed_dim - 7 == 121, i.e. embed_dim == 128
        v_p = v_p.view(-1, self.n_filters * 121)
        v_p = self.fc1_xt(v_p)
        return v_p

    def forward(self, v_d, v_p):
        """Fuse drug-graph and protein embeddings into a (batch, 512) feature.

        The final regression head is intentionally not applied here; it is
        expected to live outside this module.
        """
        v_d = self.gnn(v_d)
        v_p = self.conv_forward_xt(v_p)
        # concat along the feature axis
        v_f = torch.cat((v_d, v_p), 1)
        # dense layers
        v_f = self.fc1(v_f)
        v_f = self.relu(v_f)
        v_f = self.dropout(v_f)
        v_f = self.fc2(v_f)
        v_f = self.relu(v_f)
        v_f = self.dropout(v_f)
        return v_f