import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TransformerCPI(nn.Module):
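    """Compound-protein interaction model: a one-layer GCN embeds compound
    atoms, a gated convolutional encoder embeds the protein sequence, and a
    transformer-style decoder cross-attends atoms to residues before pooling
    them into a fixed-size representation (see the smoke test at the end of
    this file for the expected input structure)."""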
def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout, n_heads, pf_dim, atom_dim=34):
super().__init__()
self.encoder = Encoder(protein_dim, hidden_dim, n_layers, kernel_size, dropout)
self.decoder = Decoder(atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout)
self.weight = nn.Parameter(torch.FloatTensor(atom_dim, atom_dim))
self.init_weight()
def init_weight(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
    def gcn(self, x, adj):
        # One graph-convolution step: output = A @ X @ W. Any normalization
        # or self-loops on the adjacency matrix are assumed to be handled
        # upstream in the data pipeline.
        # x = [batch, num_node, atom_dim]
        # adj = [batch, num_node, num_node]
        support = torch.matmul(x, self.weight)
        # support = [batch, num_node, atom_dim]
        output = torch.bmm(adj.float(), support.float())
        # output = [batch, num_node, atom_dim]
        return output
def forward(self, compound, protein):
        # Inputs arrive as nested pairs:
        #   compound = ((atom features, atom counts), (adjacency, _))
        #   protein = (residue features, residue counts)
        compound, adj = compound
        compound, compound_lengths = compound
        adj, _ = adj
        protein, protein_lengths = protein
        # compound = [batch, atom_num, atom_dim]
        # adj = [batch, atom_num, atom_num]
        # protein = [batch, protein len, protein_dim]
        # Build padding masks that are True at valid positions and False at padding;
        # SelfAttention fills positions where mask == 0 with -1e10 before the softmax.
        compound_mask = torch.arange(compound.size(1), device=compound.device) < compound_lengths.unsqueeze(1)
        protein_mask = torch.arange(protein.size(1), device=protein.device) < protein_lengths.unsqueeze(1)
        compound_mask = compound_mask.unsqueeze(1).unsqueeze(3)  # [batch, 1, compound len, 1]
        protein_mask = protein_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, protein len]
compound = self.gcn(compound.float(), adj)
enc_src = self.encoder(protein)
# enc_src = [batch size, protein len, hid dim]
out = self.decoder(compound, enc_src, compound_mask, protein_mask)
        # out = [batch size, 256]
return out
class SelfAttention(nn.Module):
def __init__(self, hidden_dim, n_heads, dropout):
super().__init__()
self.hidden_dim = hidden_dim
self.n_heads = n_heads
        assert hidden_dim % n_heads == 0, "hidden_dim must be divisible by n_heads"
self.w_q = nn.Linear(hidden_dim, hidden_dim)
self.w_k = nn.Linear(hidden_dim, hidden_dim)
self.w_v = nn.Linear(hidden_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, hidden_dim)
self.do = nn.Dropout(dropout)
        self.scale = (hidden_dim // n_heads) ** 0.5  # sqrt(d_k) for scaled dot-product attention
def forward(self, query, key, value, mask=None):
bsz = query.shape[0]
# query = key = value [batch size, sent len, hid dim]
q = self.w_q(query)
k = self.w_k(key)
v = self.w_v(value)
# q, k, v = [batch size, sent len, hid dim]
q = q.view(bsz, -1, self.n_heads, self.hidden_dim // self.n_heads).permute(0, 2, 1, 3)
k = k.view(bsz, -1, self.n_heads, self.hidden_dim // self.n_heads).permute(0, 2, 1, 3)
v = v.view(bsz, -1, self.n_heads, self.hidden_dim // self.n_heads).permute(0, 2, 1, 3)
# k, v = [batch size, n heads, sent len_K, hid dim // n heads]
# q = [batch size, n heads, sent len_q, hid dim // n heads]
energy = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale
# energy = [batch size, n heads, sent len_Q, sent len_K]
if mask is not None:
energy = energy.masked_fill(mask == 0, -1e10)
attention = self.do(F.softmax(energy, dim=-1))
# attention = [batch size, n heads, sent len_Q, sent len_K]
x = torch.matmul(attention, v)
# x = [batch size, n heads, sent len_Q, hid dim // n heads]
x = x.permute(0, 2, 1, 3).contiguous()
# x = [batch size, sent len_Q, n heads, hid dim // n heads]
x = x.view(bsz, -1, self.n_heads * (self.hidden_dim // self.n_heads))
# x = [batch size, src sent len_Q, hid dim]
x = self.fc(x)
# x = [batch size, sent len_Q, hid dim]
return x
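# Usage sketch for the attention block on its own (shapes here are
# illustrative assumptions, not values used elsewhere in this file):
#   attn = SelfAttention(hidden_dim=64, n_heads=8, dropout=0.1)
#   x = torch.randn(2, 10, 64)
#   attn(x, x, x).shape  # -> torch.Size([2, 10, 64])
# A mask that is 0/False at padded positions excludes them from the softmax.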
class Encoder(nn.Module):
"""protein feature extraction."""
def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout):
super().__init__()
assert kernel_size % 2 == 1, "Kernel size must be odd (for now)"
self.input_dim = protein_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.scale = 0.5 ** 0.5  # 1/sqrt(2); keeps the variance of the residual sum stable
        self.convs = nn.ModuleList([
            nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size, padding=(kernel_size - 1) // 2)
            for _ in range(self.n_layers)
        ])  # gated convolutional layers; the 2*hidden_dim channels are split in half by GLU
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.input_dim, self.hidden_dim)
        self.gn = nn.GroupNorm(8, hidden_dim * 2)  # note: defined but not used in forward
self.ln = nn.LayerNorm(hidden_dim)
def forward(self, protein):
# protein = [batch size, protein len,protein_dim]
conv_input = self.fc(protein.float())
# conv_input=[batch size,protein len,hid dim]
# permute for convolutional layer
conv_input = conv_input.permute(0, 2, 1)
# conv_input = [batch size, hid dim, protein len]
for i, conv in enumerate(self.convs):
# pass through convolutional layer
conved = conv(self.dropout(conv_input))
# conved = [batch size, 2*hid dim, protein len]
# pass through GLU activation function
conved = F.glu(conved, dim=1)
# conved = [batch size, hid dim, protein len]
            # apply residual connection (highway-style)
conved = (conved + conv_input) * self.scale
# conved = [batch size, hid dim, protein len]
# set conv_input to conved for next loop iteration
conv_input = conved
conved = conved.permute(0, 2, 1)
# conved = [batch size,protein len,hid dim]
conved = self.ln(conved)
return conved
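# Usage sketch for the encoder on its own (shapes here are illustrative
# assumptions, not values used elsewhere in this file):
#   enc = Encoder(protein_dim=100, hidden_dim=64, n_layers=3, kernel_size=7, dropout=0.1)
#   enc(torch.randn(2, 120, 100)).shape  # -> torch.Size([2, 120, 64])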
class PositionwiseFeedforward(nn.Module):
def __init__(self, hidden_dim, pf_dim, dropout):
super().__init__()
self.hidden_dim = hidden_dim
self.pf_dim = pf_dim
        self.fc_1 = nn.Conv1d(hidden_dim, pf_dim, 1)  # 1x1 convolution (position-wise linear)
        self.fc_2 = nn.Conv1d(pf_dim, hidden_dim, 1)  # 1x1 convolution (position-wise linear)
self.do = nn.Dropout(dropout)
def forward(self, x):
# x = [batch size, sent len, hid dim]
x = x.permute(0, 2, 1) # x = [batch size, hid dim, sent len]
x = self.do(F.relu(self.fc_1(x))) # x = [batch size, pf dim, sent len]
x = self.fc_2(x) # x = [batch size, hid dim, sent len]
x = x.permute(0, 2, 1) # x = [batch size, sent len, hid dim]
return x
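# The two 1x1 convolutions above are equivalent to applying the same nn.Linear
# at every sequence position; the permutes only adapt the tensors to Conv1d's
# [batch, channels, length] layout.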
class DecoderLayer(nn.Module):
def __init__(self, hidden_dim, n_heads, pf_dim, dropout,
self_attention=SelfAttention,
positionwise_feedforward=PositionwiseFeedforward):
super().__init__()
        self.ln = nn.LayerNorm(hidden_dim)  # note: a single LayerNorm instance is shared by all three sub-blocks
self.sa = self_attention(hidden_dim, n_heads, dropout)
self.ea = self_attention(hidden_dim, n_heads, dropout)
self.pf = positionwise_feedforward(hidden_dim, pf_dim, dropout)
self.do = nn.Dropout(dropout)
def forward(self, trg, src, trg_mask=None, src_mask=None):
        # trg = [batch size, compound len, hidden_dim] (already projected by Decoder.ft)
        # src = [batch size, protein len, hidden_dim] (encoder output)
        # trg_mask = [batch size, 1, compound len, 1]
        # src_mask = [batch size, 1, 1, protein len]
trg = self.ln(trg + self.do(self.sa(trg, trg, trg, trg_mask)))
trg = self.ln(trg + self.do(self.ea(trg, src, src, src_mask)))
trg = self.ln(trg + self.do(self.pf(trg)))
return trg
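# Note on the block above: DecoderLayer uses post-norm residual wiring,
# x = LayerNorm(x + Dropout(sublayer(x))), in the order self-attention ->
# cross-attention -> feed-forward, as in the original Transformer decoder
# (minus causal masking, since compound atoms have no inherent order).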
class Decoder(nn.Module):
""" compound feature extraction."""
def __init__(self, atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout,
decoder_layer=DecoderLayer,
self_attention=SelfAttention,
positionwise_feedforward=PositionwiseFeedforward):
super().__init__()
self.ln = nn.LayerNorm(hidden_dim)
self.output_dim = atom_dim
self.hidden_dim = hidden_dim
self.n_layers = n_layers
self.n_heads = n_heads
self.pf_dim = pf_dim
self.decoder_layer = decoder_layer
self.self_attention = self_attention
self.positionwise_feedforward = positionwise_feedforward
self.dropout = dropout
        self.sa = self_attention(hidden_dim, n_heads, dropout)  # note: defined but not used in forward
self.layers = nn.ModuleList(
[decoder_layer(hidden_dim, n_heads, pf_dim, dropout, self_attention, positionwise_feedforward)
for _ in range(n_layers)])
self.ft = nn.Linear(atom_dim, hidden_dim)
self.do = nn.Dropout(dropout)
self.fc_1 = nn.Linear(hidden_dim, 256)
        self.gn = nn.GroupNorm(8, 256)  # note: defined but not used in forward
def forward(self, trg, src, trg_mask=None, src_mask=None):
# trg = [batch_size, compound len, atom_dim]
# src = [batch_size, protein len, hidden_dim] # encoder output
trg = self.ft(trg) # trg = [batch size, compound len, hid dim]
for layer in self.layers:
trg = layer(trg, src, trg_mask, src_mask) # trg = [batch size, compound len, hid dim]
"""Use norm to determine which atom is significant. """
norm = torch.norm(trg, dim=2) # norm = [batch size,compound len]
norm = F.softmax(norm, dim=1) # norm = [batch size,compound len]
        # attention-style pooling: weight each atom embedding by its norm-softmax score
        weighted = (trg * norm.unsqueeze(-1)).sum(dim=1)  # weighted = [batch size, hidden_dim]
        label = F.relu(self.fc_1(weighted))  # label = [batch size, 256]
        return label
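if __name__ == "__main__":
    # Minimal smoke test for the full model. Every hyperparameter and tensor
    # size below is an illustrative assumption, not a value prescribed by
    # this repository; it only demonstrates the expected input structure.
    torch.manual_seed(0)
    batch, atom_num, prot_len = 2, 30, 120
    model = TransformerCPI(protein_dim=100, hidden_dim=64, n_layers=3,
                           kernel_size=7, dropout=0.1, n_heads=8, pf_dim=256)
    model.eval()  # disable dropout for a deterministic forward pass
    compound = torch.randn(batch, atom_num, 34)  # per-atom features (atom_dim=34)
    compound_lengths = torch.tensor([30, 25])  # valid atoms per sample
    adj = torch.randint(0, 2, (batch, atom_num, atom_num)).float()  # adjacency
    protein = torch.randn(batch, prot_len, 100)  # per-residue features (protein_dim=100)
    protein_lengths = torch.tensor([120, 90])  # valid residues per sample
    out = model(((compound, compound_lengths), (adj, None)), (protein, protein_lengths))
    print(out.shape)  # expected: torch.Size([2, 256])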