Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
class TransformerCPI(nn.Module):
    """Model combining a GCN over compound atoms, a CNN protein encoder and an attention decoder.

    Args:
        protein_dim: size of each per-residue protein feature vector.
        hidden_dim: model width shared by the encoder and decoder.
        n_layers: number of encoder convolution layers and of decoder layers.
        kernel_size: encoder convolution width (the encoder requires it to be odd).
        dropout: dropout probability passed to the encoder and decoder.
        n_heads: number of attention heads in the decoder.
        pf_dim: inner width of the decoder's position-wise feed-forward block.
        atom_dim: size of each per-atom compound feature vector (default 34).
    """

    def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout, n_heads, pf_dim, atom_dim=34):
        super().__init__()
        self.encoder = Encoder(protein_dim, hidden_dim, n_layers, kernel_size, dropout)
        self.decoder = Decoder(atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout)
        # Square GCN weight: the graph convolution keeps atom_dim features per atom.
        self.weight = nn.Parameter(torch.FloatTensor(atom_dim, atom_dim))
        self.init_weight()

    def init_weight(self):
        """Fill the GCN weight uniformly in [-1/sqrt(atom_dim), 1/sqrt(atom_dim)]."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def gcn(self, input, adj):
        """Apply one graph-convolution step, ``adj @ (input @ weight)``.

        Args:
            input: atom features, [batch, num_node, atom_dim].
            adj: adjacency matrices, [batch, num_node, num_node].

        Returns:
            float32 tensor of shape [batch, num_node, atom_dim].
        """
        support = torch.matmul(input, self.weight)
        return torch.bmm(adj.float(), support.float())

    def forward(self, compound, protein):
        """Run the full model on a padded batch.

        Args:
            compound: ``((atoms, atom_lengths), (adj, _))``. ``atoms`` is
                [batch, atom_num, atom_dim], ``atom_lengths`` is [batch] (number of
                real atoms per sample) and ``adj`` is [batch, atom_num, atom_num].
            protein: ``(residues, residue_lengths)``. ``residues`` is
                [batch, protein_len, protein_dim] and ``residue_lengths`` is [batch].

        Returns:
            The decoder output, [batch, 256] pooled features (no final
            classification layer is applied here).
        """
        (compound, compound_lengths), (adj, _) = compound
        protein, protein_lengths = protein

        # Boolean masks that are True at REAL (non-padding) positions.
        # SelfAttention masks energies where ``mask == 0``, so True must mean "keep";
        # building them with ``>=`` (True at padding) would hide the real tokens instead.
        compound_mask = (torch.arange(compound.size(1), device=compound.device)
                         < compound_lengths.to(compound.device).unsqueeze(1))
        protein_mask = (torch.arange(protein.size(1), device=protein.device)
                        < protein_lengths.to(protein.device).unsqueeze(1))
        # [batch, 1, atom_num, 1]: broadcast over keys in the decoder self-attention.
        compound_mask = compound_mask.unsqueeze(1).unsqueeze(3)
        # [batch, 1, 1, protein_len]: masks protein keys in the cross-attention.
        protein_mask = protein_mask.unsqueeze(1).unsqueeze(2)

        compound = self.gcn(compound.float(), adj)   # [batch, atom_num, atom_dim]
        enc_src = self.encoder(protein)              # [batch, protein_len, hidden_dim]
        return self.decoder(compound, enc_src, compound_mask, protein_mask)
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hidden_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, n_heads, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        assert hidden_dim % n_heads == 0
        # Attribute names below are the state_dict keys; creation order fixes the init RNG sequence.
        self.w_q = nn.Linear(hidden_dim, hidden_dim)
        self.w_k = nn.Linear(hidden_dim, hidden_dim)
        self.w_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.do = nn.Dropout(dropout)
        # Energies are divided by sqrt(head_dim).
        self.scale = (hidden_dim // n_heads) ** 0.5

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query: [batch, len_q, hidden_dim].
            key, value: [batch, len_k, hidden_dim].
            mask: optional tensor broadcast against the [batch, n_heads, len_q, len_k]
                energies; every position where ``mask == 0`` is set to -1e10
                before the softmax.

        Returns:
            [batch, len_q, hidden_dim].
        """
        batch = query.shape[0]
        head_dim = self.hidden_dim // self.n_heads

        def split_heads(t):
            # [batch, len, hidden] -> [batch, heads, len, head_dim]
            return t.view(batch, -1, self.n_heads, head_dim).transpose(1, 2)

        q = split_heads(self.w_q(query))
        k = split_heads(self.w_k(key))
        v = split_heads(self.w_v(value))

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)
        weights = self.do(F.softmax(scores, dim=-1))

        # Merge heads back: [batch, heads, len_q, head_dim] -> [batch, len_q, hidden].
        context = torch.matmul(weights, v).transpose(1, 2).contiguous()
        return self.fc(context.view(batch, -1, self.n_heads * head_dim))
class Encoder(nn.Module):
    """protein feature extraction.

    Projects per-residue features to ``hidden_dim``, applies ``n_layers`` gated
    (GLU) 1-D convolutions with scaled residual connections, then LayerNorm.

    Args:
        protein_dim: size of each input residue feature vector.
        hidden_dim: model width.
        n_layers: number of convolution layers (0 is allowed: projection + LayerNorm only).
        kernel_size: convolution width; must be odd so "same" padding keeps the length.
        dropout: dropout probability applied before each convolution.
    """

    def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size must be odd (for now)"
        self.input_dim = protein_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        # Residual sums are multiplied by sqrt(0.5).
        self.scale = 0.5 ** 0.5
        # Each conv doubles the channels; GLU halves them back to hidden_dim.
        self.convs = nn.ModuleList(
            [nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size, padding=(kernel_size - 1) // 2)
             for _ in range(self.n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.input_dim, self.hidden_dim)
        # NOTE(review): gn is not used in forward(); kept so the state_dict keys
        # (and any saved checkpoints) stay compatible.
        self.gn = nn.GroupNorm(8, hidden_dim * 2)
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, protein):
        """Encode a protein batch.

        Args:
            protein: [batch, protein_len, protein_dim]; cast to float.

        Returns:
            [batch, protein_len, hidden_dim].
        """
        # Conv1d expects channels first: [batch, hidden_dim, protein_len].
        hidden = self.fc(protein.float()).permute(0, 2, 1)
        for conv in self.convs:
            gated = F.glu(conv(self.dropout(hidden)), dim=1)
            hidden = (gated + hidden) * self.scale
        # Carrying a single running tensor means n_layers == 0 no longer hits an
        # unassigned variable.
        return self.ln(hidden.permute(0, 2, 1))
class PositionwiseFeedforward(nn.Module):
    """Two-layer position-wise feed-forward block built from kernel-size-1 convolutions.

    Args:
        hidden_dim: input/output width.
        pf_dim: inner width.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, hidden_dim, pf_dim, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.pf_dim = pf_dim
        # 1x1 convolutions act as per-position linear maps over the channel axis.
        self.fc_1 = nn.Conv1d(hidden_dim, pf_dim, 1)
        self.fc_2 = nn.Conv1d(pf_dim, hidden_dim, 1)
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch, len, hidden_dim] -> [batch, len, hidden_dim]."""
        channels_first = x.transpose(1, 2)                          # [batch, hidden_dim, len]
        expanded = self.do(F.relu(self.fc_1(channels_first)))       # [batch, pf_dim, len]
        return self.fc_2(expanded).transpose(1, 2)                  # [batch, len, hidden_dim]
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder output, feed-forward.

    Every sub-layer is wrapped as ``ln(x + dropout(sublayer(x)))``; the single
    LayerNorm module ``ln`` is shared by all three sub-layers.
    """

    def __init__(self, hidden_dim, n_heads, pf_dim, dropout,
                 self_attention=SelfAttention,
                 positionwise_feedforward=PositionwiseFeedforward):
        super().__init__()
        # Attribute names are state_dict keys; creation order fixes the init RNG sequence.
        self.ln = nn.LayerNorm(hidden_dim)
        self.sa = self_attention(hidden_dim, n_heads, dropout)
        self.ea = self_attention(hidden_dim, n_heads, dropout)
        self.pf = positionwise_feedforward(hidden_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)

    def forward(self, trg, src, trg_mask=None, src_mask=None):
        """Update compound features using themselves and the protein encoding.

        Args:
            trg: compound features, [batch, compound_len, hidden_dim].
            src: encoder output, [batch, protein_len, hidden_dim].
            trg_mask: mask passed to the self-attention (TransformerCPI builds it
                as [batch, 1, compound_len, 1]).
            src_mask: mask passed to the cross-attention (TransformerCPI builds it
                as [batch, 1, 1, protein_len]).

        Returns:
            [batch, compound_len, hidden_dim].
        """
        def add_and_norm(x, update):
            return self.ln(x + self.do(update))

        trg = add_and_norm(trg, self.sa(trg, trg, trg, trg_mask))
        trg = add_and_norm(trg, self.ea(trg, src, src, src_mask))
        return add_and_norm(trg, self.pf(trg))
class Decoder(nn.Module):
    """ compound feature extraction.

    Projects atom features to ``hidden_dim``, runs them through ``n_layers``
    decoder layers that also attend to the protein encoding, pools the atoms
    with softmax-of-norm weights and applies ``relu(fc_1(.))``.

    Args:
        atom_dim: size of each input atom feature vector.
        hidden_dim: model width.
        n_layers: number of decoder layers.
        n_heads: attention heads per decoder layer.
        pf_dim: inner width of each layer's feed-forward block.
        dropout: dropout probability.
        decoder_layer, self_attention, positionwise_feedforward: classes used to
            build the layers.
    """

    def __init__(self, atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout,
                 decoder_layer=DecoderLayer,
                 self_attention=SelfAttention,
                 positionwise_feedforward=PositionwiseFeedforward):
        super().__init__()
        # NOTE(review): ln, sa, do and gn are not used in forward(); they are kept
        # so the state_dict keys (and any saved checkpoints) stay compatible.
        self.ln = nn.LayerNorm(hidden_dim)
        self.output_dim = atom_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.sa = self_attention(hidden_dim, n_heads, dropout)
        self.layers = nn.ModuleList(
            [decoder_layer(hidden_dim, n_heads, pf_dim, dropout, self_attention, positionwise_feedforward)
             for _ in range(n_layers)])
        self.ft = nn.Linear(atom_dim, hidden_dim)
        self.do = nn.Dropout(dropout)
        self.fc_1 = nn.Linear(hidden_dim, 256)
        self.gn = nn.GroupNorm(8, 256)

    def forward(self, trg, src, trg_mask=None, src_mask=None):
        """Produce pooled compound-protein features.

        Args:
            trg: atom features, [batch, compound_len, atom_dim].
            src: encoder output, [batch, protein_len, hidden_dim].
            trg_mask, src_mask: forwarded unchanged to every decoder layer.

        Returns:
            [batch, 256] features after ``relu(fc_1(pooled))``.
        """
        trg = self.ft(trg)  # [batch, compound_len, hidden_dim]
        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)

        # Use each atom's L2 norm, softmax-normalised over atoms, as its pooling
        # weight ("which atom is significant").
        # NOTE(review): trg_mask is not applied here, so padded atoms also get weight.
        weights = F.softmax(torch.norm(trg, dim=2), dim=1)  # [batch, compound_len]
        # Vectorised weighted sum over atoms (replaces a per-(sample, atom) Python loop).
        pooled = (trg * weights.unsqueeze(-1)).sum(dim=1)   # [batch, hidden_dim]
        return F.relu(self.fc_1(pooled))