# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements key components of the Transformer model: scaled dot
product attention and multi-head attention."""

import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Scaled dot product attention as described in Eqn 1 of Vaswani et al. 2017
    [https://arxiv.org/abs/1706.03762].

        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

    Note that the dimension of the queries has to match the dimension of the keys
    (here specified as ``d_k``) and the length of the keys has to match the length
    of the values. See for instance 'The Illustrated Transformer'
    [http://jalammar.github.io/illustrated-transformer/] for a pictorial depiction
    of attention.

    Inputs:
        Q (torch.tensor): of shape (batch_size, sequence_length_q, d_k)
        K (torch.tensor): of shape (batch_size, sequence_length_k, d_k)
        V (torch.tensor): of shape (batch_size, sequence_length_k, d_v)
        mask (torch.tensor, optional): of dtype bool or byte and shape
            (batch_size, 1, sequence_length_k); zeros (or False) indicate positions
            that cannot contribute to attention.

    Outputs:
        output (torch.tensor): of shape (batch_size, sequence_length_q, d_v). Each
            entry output[b, i, :] is a convex combination of the value vectors,
            sum_j a_j * V[b, j, :] with sum_j a_j = 1.
        attentions (torch.tensor): of shape
            (batch_size, sequence_length_q, sequence_length_k), where the
            [b, i, j]-element is the attention weight (i.e. the relative
            contribution) of position j in the key tensor to position i in the
            query tensor for element b of the batch.
    """

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask=None):
        _, _, d = K.size()
        attn = torch.bmm(Q, K.transpose(1, 2)) / d ** 0.5
        if mask is not None:
            attn = attn.float().masked_fill(mask == 0, -1e9)
        attn = self.softmax(attn)
        if mask is not None:
            attn = attn.float().masked_fill(mask == 0, 0)
            if V.dtype == torch.float16:
                attn = attn.half()
        output = torch.bmm(attn, V)
        return output, attn
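
# Usage sketch for ScaledDotProductAttention (not part of the original module;
# the shapes below are arbitrary illustrative values):
#
#   attn = ScaledDotProductAttention()
#   Q = torch.randn(2, 5, 64)                     # (batch, len_q, d_k)
#   K = torch.randn(2, 7, 64)                     # (batch, len_k, d_k)
#   V = torch.randn(2, 7, 32)                     # (batch, len_k, d_v)
#   mask = torch.ones(2, 1, 7, dtype=torch.bool)  # all positions allowed
#   out, weights = attn(Q, K, V, mask=mask)       # out: (2, 5, 32); weights: (2, 5, 7)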


class MultiHeadAttention(nn.Module):
    """Multi-head attention with scaled dot product attention. See 'The Annotated
    Transformer' [http://nlp.seas.harvard.edu/2018/04/03/attention.html] or 'The
    Illustrated Transformer' [http://jalammar.github.io/illustrated-transformer/]
    for details and intuition.

    Args:
        n_head (int): number of attention heads
        d_k (int): dimension of the keys and queries in each attention head
        d_v (int): dimension of the values in each attention head
        d_model (int): input and output dimension for the layer
        dropout (float): dropout rate, default is 0.1

    Inputs:
        Q (torch.tensor): query tensor of shape (batch_size, sequence_length_q, d_model)
        K (torch.tensor): key tensor of shape (batch_size, sequence_length_k, d_model)
        V (torch.tensor): value tensor of shape (batch_size, sequence_length_k, d_model)
        mask (torch.tensor, optional): of dtype bool or byte and shape
            (batch_size, 1, sequence_length_k); zeros (or False) indicate positions
            that cannot contribute to attention.

    Outputs:
        output (torch.tensor): of shape (batch_size, sequence_length_q, d_model)
        attentions (torch.tensor): of shape
            (batch_size * n_head, sequence_length_q, sequence_length_k), where
            attentions[batch_size * i : batch_size * (i + 1), :, :] corresponds to
            the batch of attention maps for the i-th head. See
            chroma.layers.attention.ScaledDotProductAttention for more details.
    """

    def __init__(self, n_head, d_k, d_v, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.Wq = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
        self.Wk = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
        self.Wv = nn.Parameter(torch.Tensor(n_head, d_model, d_v))
        self.Wo = nn.Parameter(torch.Tensor(n_head * d_v, d_model))
        self.attention = ScaledDotProductAttention()
        self.dropout = nn.Dropout(p=dropout)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_normal_(self.Wq)
        nn.init.xavier_normal_(self.Wk)
        nn.init.xavier_normal_(self.Wv)
        nn.init.kaiming_uniform_(self.Wo)

    def forward(self, Q, K, V, bias=None, mask=None):
        mb_size, len_q, d_q_in = Q.size()
        mb_size, len_k, d_k_in = K.size()
        mb_size, len_v, d_v_in = V.size()
        d_model = self.d_model
        if d_q_in != d_model:
            raise ValueError("Dimension of Q does not match d_model.")
        if d_k_in != d_model:
            raise ValueError("Dimension of K does not match d_model.")
        if d_v_in != d_model:
            raise ValueError("Dimension of V does not match d_model.")

        # Treat the heads as an (n_head * mb_size) batch and project to d_k and d_v
        q_s = torch.cat([Q @ W for W in self.Wq])  # (n_head * mb_size) x len_q x d_k
        k_s = torch.cat([K @ W for W in self.Wk])  # (n_head * mb_size) x len_k x d_k
        v_s = torch.cat([V @ W for W in self.Wv])  # (n_head * mb_size) x len_v x d_v

        # Attention
        if mask is not None:
            mask = mask.repeat(self.n_head, 1, 1)
        outputs, attns = self.attention(q_s, k_s, v_s, mask=mask)

        # Back to the original mb_size batch; result size = mb_size x len_q x (n_head * d_v)
        outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)

        # Project back to the residual size
        outputs = outputs @ self.Wo
        outputs = self.dropout(outputs)
        return outputs, attns
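
# Usage sketch for MultiHeadAttention (not part of the original module; the head
# count and dimensions below are arbitrary illustrative values):
#
#   mha = MultiHeadAttention(n_head=4, d_k=16, d_v=16, d_model=64)
#   x = torch.randn(2, 10, 64)                    # (batch, length, d_model)
#   mask = torch.ones(2, 1, 10, dtype=torch.bool)
#   out, attns = mha(x, x, x, mask=mask)          # out: (2, 10, 64); attns: (8, 10, 10)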


class AttentionChainPool(nn.Module):
    """Pools residue-based representations to chain-based representations using a
    chain mask and attention.

    Args:
        n_head (int): number of attention heads
        d_model (int): dimension of the embeddings to be pooled

    Inputs:
        h (torch.tensor): of shape (batch_size, sequence_length, d_model)
        C (torch.tensor): of shape (batch_size, sequence_length)

    Outputs:
        output (torch.tensor): of shape (batch_size, n_chains, d_model)
        chain_mask (torch.tensor): of shape (batch_size, n_chains)
    """

    def __init__(self, n_head, d_model):
        super().__init__()
        self.attention = MultiHeadAttention(
            n_head, d_model, d_model, d_model, dropout=0.0
        )

    def get_query(self, x):
        return torch.ones(x.size(0), 1, x.size(2)).type(x.dtype).to(x.device)

    def forward(self, h, C):
        bs, num_res = C.size()
        chains = C.abs().unique()
        chains = (
            chains[chains > 0].unsqueeze(-1).repeat(1, bs).reshape(-1).unsqueeze(-1)
        )
        num_chains = len(chains.unique())

        h_repeat = h.repeat(num_chains, 1, 1)
        C_repeat = C.repeat(num_chains, 1)
        mask = (C_repeat == chains).unsqueeze(-2)

        output, _ = self.attention(
            self.get_query(h_repeat), h_repeat, h_repeat, mask=mask
        )
        output = torch.cat(output.split(bs), 1)
        chain_mask = torch.stack(mask.squeeze(1).any(dim=-1).split(bs), -1)
        return output, chain_mask
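
# Usage sketch for AttentionChainPool (not part of the original module; the chain
# map below is an arbitrary illustrative example with two chains per item):
#
#   pool = AttentionChainPool(n_head=4, d_model=64)
#   h = torch.randn(2, 6, 64)                     # residue embeddings
#   C = torch.tensor([[1, 1, 1, 2, 2, 2],         # chain index per residue
#                     [1, 1, 2, 2, 2, 2]])
#   pooled, chain_mask = pool(h, C)               # pooled: (2, 2, 64); chain_mask: (2, 2)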


class Attention(nn.Module):
    """A multi-head attention layer with optional gating and bias, as implemented
    in Jumper et al. (2021).

    Args:
        n_head (int): number of attention heads
        d_model (int): dimension of the inputs and outputs
        d_k (int): dimension of the keys/queries per head (default: d_model // n_head)
        d_v (int): dimension of the values per head (default: d_model // n_head)
        gate (bool): whether to include a gating connection (as in Jumper et al. (2021))

    Inputs:
        Q (torch.tensor): of shape (batch_size, num_queries, d_model)
        K (torch.tensor): of shape (batch_size, num_keys, d_model)
        V (torch.tensor): of shape (batch_size, num_keys, d_model)
        bias (torch.tensor, optional): of shape (batch_size, n_head, num_queries, num_keys)
            or (n_head, num_queries, num_keys)
        mask (torch.tensor, optional): of dtype bool and shape
            (batch_size, n_head, num_queries, num_keys)

    Outputs:
        output (torch.tensor): of shape (batch_size, num_queries, d_model)
    """

    def __init__(self, n_head, d_model, d_k=None, d_v=None, gate=False):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_model // n_head if d_k is None else d_k
        self.d_v = d_model // n_head if d_v is None else d_v
        self.gate = gate

        self.q_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
        self.k_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
        self.v_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
        self.o_weights = nn.Parameter(torch.Tensor(n_head, self.d_v, d_model))
        self.o_bias = nn.Parameter(torch.Tensor(d_model))
        if self.gate:
            self.g_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
            self.g_bias = nn.Parameter(torch.Tensor(n_head, self.d_v))
        self.softmax = nn.Softmax(dim=-1)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_weights)
        nn.init.xavier_uniform_(self.k_weights)
        nn.init.xavier_uniform_(self.v_weights)
        nn.init.xavier_uniform_(self.o_weights)
        nn.init.zeros_(self.o_bias)
        if self.gate:
            nn.init.zeros_(self.g_weights)
            nn.init.ones_(self.g_bias)

    def forward(self, Q, K, V, bias=None, mask=None):
        self._check_inputs(Q, K, V, bias, mask)

        # Project inputs to per-head queries, keys, and values
        # (queries are pre-scaled by 1 / sqrt(d_k))
        q = torch.einsum("bqa,ahc->bqhc", Q, self.q_weights) * self.d_k ** (-0.5)
        k = torch.einsum("bka,ahc->bkhc", K, self.k_weights)
        v = torch.einsum("bka,ahc->bkhc", V, self.v_weights)

        logits = torch.einsum("bqhc,bkhc->bhqk", q, k)
        if bias is not None:
            logits = logits + bias
        weights = self.softmax(logits)
        if mask is not None:
            weights = weights.masked_fill(~mask, 0.0)

        weighted_avg = torch.einsum("bhqk,bkhc->bqhc", weights, v)
        if self.gate:
            gate_values = torch.einsum("bqa,ahc->bqhc", Q, self.g_weights) + self.g_bias
            gate_values = torch.sigmoid(gate_values)
            weighted_avg = weighted_avg * gate_values

        output = (
            torch.einsum("bqhc,hco->bqo", weighted_avg, self.o_weights) + self.o_bias
        )
        return output

    def _check_inputs(self, Q, K, V, bias, mask):
        batch_size_q, num_queries, d_q_in = Q.size()
        batch_size_k, num_keys, d_k_in = K.size()
        batch_size_v, num_values, d_v_in = V.size()
        if d_q_in != self.d_model:
            raise ValueError(
                "Dimension of Q tensor needs to be (batch_size, number_queries, d_model)"
            )
        if d_k_in != self.d_model:
            raise ValueError(
                "Dimension of K tensor needs to be (batch_size, number_keys, d_model)"
            )
        if d_v_in != self.d_model:
            raise ValueError(
                "Dimension of V tensor needs to be (batch_size, number_values, d_model)"
            )
        if num_keys != num_values:
            raise ValueError("Number of keys needs to match number of values passed")
        if (batch_size_q != batch_size_k) or (batch_size_k != batch_size_v):
            raise ValueError(
                "Found batch size mismatch among inputs, all tensors must agree in size of dimension 0"
            )

        if bias is not None:
            if (bias.dim() != 3) and (bias.dim() != 4):
                raise ValueError(
                    f"Bias specified but dimension mismatched: passed {bias.dim()}-dimensional tensor "
                    f"but should be 3-dimensional of shape (n_head, num_queries, num_keys) "
                    f"or 4-dimensional of shape (batch_size, n_head, num_queries, num_keys)"
                )
            if bias.dim() == 3:
                n_head_b, num_queries_b, num_keys_b = bias.size()
                if n_head_b != self.n_head:
                    raise ValueError(
                        f"Bias specified but number of heads (dim of axis=0) does not match number of heads: {self.n_head}"
                    )
                if num_queries_b != num_queries:
                    raise ValueError(
                        "Bias specified but number of queries (dim of axis=1) does not match number of queries given in Q tensor"
                    )
                if num_keys_b != num_keys:
                    raise ValueError(
                        "Bias specified but number of keys (dim of axis=2) does not match number of keys given in K tensor "
                        "(dimension of axis=1)"
                    )
            elif bias.dim() == 4:
                n_batch_b, n_head_b, num_queries_b, num_keys_b = bias.size()
                if n_head_b != self.n_head:
                    raise ValueError(
                        f"Bias specified but number of heads (dim of axis=1) does not match number of heads: {self.n_head}"
                    )
                if num_queries_b != num_queries:
                    raise ValueError(
                        "Bias specified but number of queries (dim of axis=2) does not match number of queries given in Q tensor"
                    )
                if num_keys_b != num_keys:
                    raise ValueError(
                        "Bias specified but number of keys (dim of axis=3) does not match number of keys given in K tensor "
                        "(dimension of axis=1)"
                    )

        if mask is not None:
            if mask.dtype != torch.bool:
                raise ValueError(
                    f"Mask specified but not given the correct dtype, should be torch.bool but found {mask.dtype}"
                )
            if mask.dim() != 4:
                raise ValueError(
                    f"Mask specified but dimension mismatched: passed {mask.dim()}-dimensional tensor "
                    f"but should be 4-dimensional of shape (batch_size, n_head, num_queries, num_keys)"
                )
            batch_size_b, _, num_queries_b, num_keys_b = mask.size()
            if (num_queries_b != num_queries) and (num_queries_b != 1):
                raise ValueError(
                    "Mask specified but number of queries (dim of axis=2) does not match number of queries given in Q tensor"
                )
            if (num_keys_b != num_keys) and (num_keys_b != 1):
                raise ValueError(
                    "Mask specified but number of keys (dim of axis=3) does not match number of keys given in K tensor "
                    "(dimension of axis=1)"
                )