from math import floor
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv(in_channels, out_channels, kernel_size, conv_dim, stride=1):
    """Build a 1D/2D/3D convolution with 'same'-style padding for odd kernel sizes."""
    match conv_dim:
        case 1:
            conv_layer = nn.Conv1d
        case 2:
            conv_layer = nn.Conv2d
        case 3:
            conv_layer = nn.Conv3d
        case _:
            raise ValueError(f"conv_dim must be 1, 2 or 3, got {conv_dim}")
    return conv_layer(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=floor(kernel_size / 2), bias=False)


def batch_norm(out_channels, conv_dim):
    """Build the BatchNorm layer matching the convolution dimensionality."""
    match conv_dim:
        case 1:
            bn_layer = nn.BatchNorm1d
        case 2:
            bn_layer = nn.BatchNorm2d
        case 3:
            bn_layer = nn.BatchNorm3d
        case _:
            raise ValueError(f"conv_dim must be 1, 2 or 3, got {conv_dim}")
    return bn_layer(out_channels)
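

# Example (illustrative): conv(1, 16, kernel_size=3, conv_dim=2) builds
# nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False), and
# batch_norm(16, conv_dim=2) builds nn.BatchNorm2d(16).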


# Fixed 2D convolution helpers (kept for reference; superseded by conv() above).
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


def conv5x5(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=5,
                     stride=stride, padding=2, bias=False)


def conv1x1(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, padding=0, bias=False)


# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, conv_dim, stride=1, downsample=None):
        super().__init__()
        # self.conv1 = conv5x5(in_channels, out_channels, stride)
        self.conv1 = conv(in_channels, out_channels, kernel_size=5, conv_dim=conv_dim, stride=stride)
        self.bn1 = batch_norm(out_channels, conv_dim=conv_dim)
        self.elu = nn.ELU(inplace=True)
        # self.conv2 = conv3x3(out_channels, out_channels)
        # Only the first convolution strides; a strided conv2 would shrink the output a
        # second time and break the shape match with the (single-stride) shortcut branch.
        self.conv2 = conv(out_channels, out_channels, kernel_size=3, conv_dim=conv_dim)
        self.bn2 = batch_norm(out_channels, conv_dim=conv_dim)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.elu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.elu(out)
        return out


class DrugVQA(nn.Module):
    """
    Implementation of the DrugVQA model with regularization and without pruning.
    Slight modifications have been made for speed.
    """

    def __init__(
        self,
        conv_dim: Literal[1, 2, 3],
        lstm_hid_dim: int,
        d_a: int,
        r: int,
        n_chars_smi: int,
        n_chars_seq: int,
        dropout: float,
        in_channels: int,
        cnn_channels: int,
        cnn_layers: int,
        emb_dim: int,
        dense_hid: int,
    ):
        """
        conv_dim    : {int} dimensionality of the CNN blocks (1, 2 or 3)
        lstm_hid_dim: {int} hidden dimension of the LSTM
        d_a         : {int} hidden dimension of the attention dense layer
        r           : {int} number of attention hops (attention heads)
        n_chars_smi : {int} vocabulary size of the SMILES strings
        n_chars_seq : {int} vocabulary size of the protein sequences
        dropout     : {float} dropout rate of the LSTM
        in_channels : {int} channels of the CNN block input
        cnn_channels: {int} channels inside the CNN blocks
        cnn_layers  : {int} number of residual blocks per CNN layer
        emb_dim     : {int} embedding dimension
        dense_hid   : {int} hidden dimension of the output dense layer
        """
        super().__init__()
        self.conv_dim = conv_dim
        self.lstm_hid_dim = lstm_hid_dim
        self.r = r
        self.in_channels = in_channels
        # rnn
        self.embeddings = nn.Embedding(n_chars_smi, emb_dim)
        # self.seq_embed = nn.Embedding(n_chars_seq, emb_dim)
        self.lstm = nn.LSTM(emb_dim, self.lstm_hid_dim, 2, batch_first=True, bidirectional=True,
                            dropout=dropout)
        self.linear_first = nn.Linear(2 * self.lstm_hid_dim, d_a)
        self.linear_second = nn.Linear(d_a, r)
        self.linear_first_seq = nn.Linear(cnn_channels, d_a)
        self.linear_second_seq = nn.Linear(d_a, self.r)
        # cnn
        # self.conv = conv3x3(1, self.in_channels)
        self.conv = conv(1, self.in_channels, kernel_size=3, conv_dim=conv_dim)
        self.bn = batch_norm(in_channels, conv_dim=conv_dim)
        self.elu = nn.ELU(inplace=False)
        self.layer1 = self.make_layer(cnn_channels, cnn_layers)
        self.layer2 = self.make_layer(cnn_channels, cnn_layers)
        self.linear_final_step = nn.Linear(self.lstm_hid_dim * 2 + d_a, dense_hid)
        # self.linear_final = nn.Linear(dense_hid, n_classes)
        self.softmax = nn.Softmax(dim=1)

    # @staticmethod
    # def softmax(input, axis=1):
    #     """
    #     Softmax applied to axis=n
    #     Args:
    #         input: {Tensor,Variable} input on which softmax is to be applied
    #         axis : {int} axis on which softmax is to be applied
    #
    #     Returns:
    #         softmaxed tensors
    #     """
    #     input_size = input.size()
    #     trans_input = input.transpose(axis, len(input_size) - 1)
    #     trans_size = trans_input.size()
    #     input_2d = trans_input.contiguous().view(-1, trans_size[-1])
    #     soft_max_2d = F.softmax(input_2d)
    #     soft_max_nd = soft_max_2d.view(*trans_size)
    #     return soft_max_nd.transpose(axis, len(input_size) - 1)

    def make_layer(self, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                # conv3x3(self.in_channels, out_channels, stride=stride),
                conv(self.in_channels, out_channels, kernel_size=3, conv_dim=self.conv_dim, stride=stride),
                batch_norm(out_channels, conv_dim=self.conv_dim),
            )
        layers = [ResidualBlock(self.in_channels, out_channels,
                                conv_dim=self.conv_dim, stride=stride, downsample=downsample)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels, conv_dim=self.conv_dim))
        return nn.Sequential(*layers)

    def forward(self, enc_drug, enc_protein):
        enc_drug, _ = enc_drug
        enc_protein, _ = enc_protein
        smile_embed = self.embeddings(enc_drug.long())
        # self.hidden_state = tuple(hidden_state.to(smile_embed).detach() for hidden_state in self.hidden_state)
        outputs, _ = self.lstm(smile_embed)
        sentence_att = torch.tanh(self.linear_first(outputs))
        sentence_att = self.linear_second(sentence_att)
        sentence_att = self.softmax(sentence_att)
        sentence_att = sentence_att.transpose(1, 2)
        sentence_embed = sentence_att @ outputs
        avg_sentence_embed = torch.sum(sentence_embed, 1) / self.r  # multi head
        pic = self.conv(enc_protein.float().unsqueeze(1))
        pic = self.bn(pic)
        pic = self.elu(pic)
        pic = self.layer1(pic)
        pic = self.layer2(pic)
        pic_emb = torch.mean(pic, 2).unsqueeze(2)
        pic_emb = pic_emb.permute(0, 2, 1)
        seq_att = torch.tanh(self.linear_first_seq(pic_emb))
        seq_att = self.linear_second_seq(seq_att)
        seq_att = self.softmax(seq_att)
        seq_att = seq_att.transpose(1, 2)
        seq_embed = seq_att @ pic_emb
        avg_seq_embed = torch.sum(seq_embed, 1) / self.r
        sscomplex = torch.cat([avg_sentence_embed, avg_seq_embed], dim=1)
        sscomplex = F.relu(self.linear_final_step(sscomplex))
        # if not bool(self.type):
        #     output = F.sigmoid(self.linear_final(sscomplex))
        #     return output, seq_att
        # else:
        #     return F.log_softmax(self.linear_final(sscomplex)), seq_att
        return sscomplex, seq_att


class AttentionL2Regularization(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, seq_att):
        batch_size = seq_att.size(0)
        identity = torch.eye(seq_att.size(1), device=seq_att.device)
        identity = identity.unsqueeze(0).expand(batch_size, seq_att.size(1), seq_att.size(1))
        loss = torch.mean(self.l2_matrix_norm(seq_att @ seq_att.transpose(1, 2) - identity))
        return loss

    @staticmethod
    def l2_matrix_norm(m):
        """
        Frobenius norm of m = A @ A^T - I, computed per batch element.

        Missing from the original DrugVQA GitHub source code.
        Opting for the faster Frobenius norm rather than the induced L2 matrix norm
        (spectral norm) proposed in the original research, because the goal is simply
        to minimize the difference between the attention matrix and the identity matrix.
        """
        return torch.linalg.norm(m, ord='fro', dim=(1, 2))
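

# Minimal smoke test (illustrative only; the hyperparameters below are arbitrary
# assumptions, not values from the original DrugVQA configuration). It uses
# conv_dim=1, i.e. the protein input is a 1D encoded sequence of shape
# (batch, seq_len). Note that d_a is set equal to cnn_channels, because
# linear_final_step expects 2 * lstm_hid_dim + d_a input features while the
# concatenated embedding has 2 * lstm_hid_dim + cnn_channels features.
if __name__ == "__main__":
    model = DrugVQA(
        conv_dim=1,
        lstm_hid_dim=64,
        d_a=32,
        r=10,
        n_chars_smi=100,
        n_chars_seq=25,
        dropout=0.2,
        in_channels=8,
        cnn_channels=32,
        cnn_layers=2,
        emb_dim=30,
        dense_hid=128,
    )
    att_reg = AttentionL2Regularization()

    batch_size, smi_len, seq_len = 4, 50, 200
    # forward() expects (tensor, _) pairs and discards the second element.
    enc_drug = (torch.randint(0, 100, (batch_size, smi_len)), None)
    enc_protein = (torch.randint(0, 25, (batch_size, seq_len)), None)

    with torch.no_grad():
        sscomplex, seq_att = model(enc_drug, enc_protein)
        penalty = att_reg(seq_att)

    print(sscomplex.shape)  # torch.Size([4, 128])
    print(seq_att.shape)    # torch.Size([4, 10, 1])
    print(penalty.item())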