import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerCPI(nn.Module):
    """Compound-protein interaction model.

    Atom features go through one graph-convolution step (``gcn``), protein
    features through a convolutional ``Encoder``, and the two are combined by
    an attention-based ``Decoder``.

    Args:
        protein_dim: feature size of each protein position.
        hidden_dim: model width shared by the encoder and decoder.
        n_layers: number of conv layers in the encoder and of decoder layers.
        kernel_size: encoder convolution width (must be odd).
        dropout: dropout probability.
        n_heads: number of attention heads (must divide ``hidden_dim``).
        pf_dim: inner width of the position-wise feed-forward network.
        atom_dim: feature size of each atom.
    """

    def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout,
                 n_heads, pf_dim, atom_dim=34):
        super().__init__()
        self.encoder = Encoder(protein_dim, hidden_dim, n_layers, kernel_size, dropout)
        self.decoder = Decoder(atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout)
        # Weight of the single GCN step applied to the atom features.
        self.weight = nn.Parameter(torch.FloatTensor(atom_dim, atom_dim))
        self.init_weight()

    def init_weight(self):
        """Initialise the GCN weight uniformly in [-1/sqrt(d), 1/sqrt(d)]."""
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def gcn(self, input, adj):
        """Apply one graph convolution: ``adj @ (input @ weight)``.

        Args:
            input: atom features, [batch, num_node, atom_dim].
            adj: adjacency matrices, [batch, num_node, num_node].

        Returns:
            Tensor of shape [batch, num_node, atom_dim] (float32).
        """
        support = torch.matmul(input, self.weight)  # [batch, num_node, atom_dim]
        return torch.bmm(adj.float(), support.float())

    def forward(self, compound, protein):
        """Score a batch of compound/protein pairs.

        Args:
            compound: ``((atom_features, atom_lengths), (adj, _))`` where
                atom_features is [batch, atom_num, atom_dim], adj is
                [batch, atom_num, atom_num] and atom_lengths is [batch].
            protein: ``(protein_features, protein_lengths)`` where
                protein_features is [batch, protein_len, protein_dim].

        Returns:
            Tensor of shape [batch, 256] (ReLU output of ``Decoder.fc_1``).
        """
        (compound, compound_lengths), (adj, _) = compound
        protein, protein_lengths = protein

        # Masks are True at real (non-padded) positions. SelfAttention hides
        # every position whose mask value is 0/False, so padding must be False.
        compound_mask = (torch.arange(compound.size(1), device=compound.device)
                         < compound_lengths.to(compound.device).unsqueeze(1))
        protein_mask = (torch.arange(protein.size(1), device=protein.device)
                        < protein_lengths.to(protein.device).unsqueeze(1))
        compound_mask = compound_mask.unsqueeze(1).unsqueeze(3)  # [batch, 1, atom_num, 1]
        protein_mask = protein_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, protein_len]

        compound = self.gcn(compound.float(), adj)  # [batch, atom_num, atom_dim]
        enc_src = self.encoder(protein)  # [batch, protein_len, hidden_dim]
        # NOTE(review): the compound mask only masks query rows in the decoder's
        # self-attention, and the decoder's norm-weighted pooling also includes
        # padded atoms; confirm this is intended for padded batches.
        return self.decoder(compound, enc_src, compound_mask, protein_mask)


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hidden_dim: model width (must be divisible by ``n_heads``).
        n_heads: number of attention heads.
        dropout: dropout applied to the attention weights.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hidden_dim, n_heads, dropout):
        super().__init__()
        if hidden_dim % n_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by n_heads ({n_heads})")
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.w_q = nn.Linear(hidden_dim, hidden_dim)
        self.w_k = nn.Linear(hidden_dim, hidden_dim)
        self.w_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.do = nn.Dropout(dropout)
        self.scale = (hidden_dim // n_heads) ** 0.5

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: [batch, len_q, hidden_dim].
            key, value: [batch, len_k, hidden_dim].
            mask: optional tensor broadcastable to [batch, n_heads, len_q, len_k];
                positions where ``mask == 0`` are excluded.

        Returns:
            Tensor of shape [batch, len_q, hidden_dim].
        """
        bsz = query.shape[0]
        head_dim = self.hidden_dim // self.n_heads

        # Project, then split into heads: [batch, n_heads, len, head_dim].
        q = self.w_q(query).view(bsz, -1, self.n_heads, head_dim).permute(0, 2, 1, 3)
        k = self.w_k(key).view(bsz, -1, self.n_heads, head_dim).permute(0, 2, 1, 3)
        v = self.w_v(value).view(bsz, -1, self.n_heads, head_dim).permute(0, 2, 1, 3)

        energy = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale  # [batch, heads, len_q, len_k]
        if mask is not None:
            # Use the dtype's most negative finite value: a literal like -1e10
            # overflows (and raises) in float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = self.do(F.softmax(energy, dim=-1))
        x = torch.matmul(attention, v)  # [batch, heads, len_q, head_dim]
        # Merge heads back: [batch, len_q, hidden_dim].
        x = x.permute(0, 2, 1, 3).contiguous().view(bsz, -1, self.n_heads * head_dim)
        return self.fc(x)


class Encoder(nn.Module):
    """Protein feature extraction with gated (GLU) residual 1-D convolutions.

    Args:
        protein_dim: feature size of each protein position.
        hidden_dim: model width.
        n_layers: number of convolution layers (0 is allowed).
        kernel_size: convolution width; must be odd so padding keeps the length.
        dropout: dropout applied before each convolution.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout):
        super().__init__()
        if kernel_size % 2 != 1:
            raise ValueError("Kernel size must be odd (for now)")
        self.input_dim = protein_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        # Residual scaling factor sqrt(0.5).
        self.scale = 0.5 ** 0.5
        # Each conv outputs 2*hidden_dim channels, which GLU halves back to hidden_dim.
        self.convs = nn.ModuleList(
            [nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size,
                       padding=(kernel_size - 1) // 2)
             for _ in range(self.n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.input_dim, self.hidden_dim)
        # Not used in forward; kept so state_dict keys stay unchanged.
        self.gn = nn.GroupNorm(8, hidden_dim * 2)
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, protein):
        """Encode protein features.

        Args:
            protein: [batch, protein_len, protein_dim].

        Returns:
            Tensor of shape [batch, protein_len, hidden_dim].
        """
        # Conv1d wants channels first: [batch, hidden_dim, protein_len].
        conv_input = self.fc(protein.float()).permute(0, 2, 1)
        for conv in self.convs:
            conved = F.glu(conv(self.dropout(conv_input)), dim=1)  # [batch, hidden_dim, len]
            # Scaled residual connection.
            conv_input = (conved + conv_input) * self.scale
        return self.ln(conv_input.permute(0, 2, 1))


class PositionwiseFeedforward(nn.Module):
    """Two-layer position-wise feed-forward network built from kernel-size-1 convolutions."""

    def __init__(self, hidden_dim, pf_dim, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.pf_dim = pf_dim
        self.fc_1 = nn.Conv1d(hidden_dim, pf_dim, 1)
        self.fc_2 = nn.Conv1d(pf_dim, hidden_dim, 1)
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch, len, hidden_dim] to [batch, len, hidden_dim]."""
        x = x.permute(0, 2, 1)  # [batch, hidden_dim, len]
        x = self.do(F.relu(self.fc_1(x)))  # [batch, pf_dim, len]
        x = self.fc_2(x)  # [batch, hidden_dim, len]
        return x.permute(0, 2, 1)


class DecoderLayer(nn.Module):
    """Self-attention, cross-attention to the protein, then feed-forward.

    Each sub-layer uses a residual connection followed by LayerNorm. A single
    LayerNorm module (``self.ln``) is shared by all three sub-layers.
    """

    def __init__(self, hidden_dim, n_heads, pf_dim, dropout,
                 self_attention=SelfAttention,
                 positionwise_feedforward=PositionwiseFeedforward):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dim)
        self.sa = self_attention(hidden_dim, n_heads, dropout)
        self.ea = self_attention(hidden_dim, n_heads, dropout)
        self.pf = positionwise_feedforward(hidden_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)

    def forward(self, trg, src, trg_mask=None, src_mask=None):
        """Update compound features ``trg`` using protein encodings ``src``.

        Args:
            trg: [batch, compound_len, hidden_dim].
            src: [batch, protein_len, hidden_dim] (encoder output).
            trg_mask: mask for self-attention (``== 0`` entries are hidden).
            src_mask: mask for cross-attention (``== 0`` entries are hidden).

        Returns:
            Tensor of shape [batch, compound_len, hidden_dim].
        """
        trg = self.ln(trg + self.do(self.sa(trg, trg, trg, trg_mask)))
        trg = self.ln(trg + self.do(self.ea(trg, src, src, src_mask)))
        trg = self.ln(trg + self.do(self.pf(trg)))
        return trg


class Decoder(nn.Module):
    """Compound feature extraction conditioned on the protein encoding.

    Projects atom features to ``hidden_dim``, runs ``n_layers`` decoder
    layers, pools atoms with softmax-of-norm weights and applies ``fc_1`` + ReLU.
    """

    def __init__(self, atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout,
                 decoder_layer=DecoderLayer, self_attention=SelfAttention,
                 positionwise_feedforward=PositionwiseFeedforward):
        super().__init__()
        self.output_dim = atom_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        # ln, sa, do and gn are not used in forward; kept so state_dict keys
        # (and therefore saved checkpoints) stay compatible.
        self.ln = nn.LayerNorm(hidden_dim)
        self.sa = self_attention(hidden_dim, n_heads, dropout)
        self.layers = nn.ModuleList(
            [decoder_layer(hidden_dim, n_heads, pf_dim, dropout,
                           self_attention, positionwise_feedforward)
             for _ in range(n_layers)])
        self.ft = nn.Linear(atom_dim, hidden_dim)
        self.do = nn.Dropout(dropout)
        self.fc_1 = nn.Linear(hidden_dim, 256)
        self.gn = nn.GroupNorm(8, 256)

    def forward(self, trg, src, trg_mask=None, src_mask=None):
        """Decode compound features.

        Args:
            trg: atom features, [batch, compound_len, atom_dim].
            src: protein encodings, [batch, protein_len, hidden_dim].
            trg_mask, src_mask: attention masks forwarded to each layer.

        Returns:
            Tensor of shape [batch, 256].
        """
        trg = self.ft(trg)  # [batch, compound_len, hidden_dim]
        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)

        # Use each atom's feature norm to decide how much it contributes.
        weights = F.softmax(torch.norm(trg, dim=2), dim=1)  # [batch, compound_len]
        pooled = (trg * weights.unsqueeze(-1)).sum(dim=1)  # [batch, hidden_dim]
        return F.relu(self.fc_1(pooled))