# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module implements the key attention components of the Transformer architecture:
scaled dot-product attention (ScaledDotProductAttention), multi-head attention
(MultiHeadAttention), attention-based pooling of residues into chains
(AttentionChainPool), and gated multi-head attention with an optional bias (Attention).
"""
import torch
import torch.nn as nn
class ScaledDotProductAttention(nn.Module):
"""Scaled dot product attention as described in Eqn 1 of Vaswani et al. 2017 [https://arxiv.org/abs/1706.03762].
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
    Note that the dimension of the queries has to match the dimension of the keys (here specified as ```d_k```) and the length of the keys has to match
    the length of the values. See, for instance, 'The Illustrated Transformer' [http://jalammar.github.io/illustrated-transformer/]
    for a pictorial depiction of attention.
Inputs:
Q (torch.tensor): of shape (batch_size, sequence_length_q, d_k)
K (torch.tensor): of shape (batch_size, sequence_length_k, d_k)
V (torch.tensor): of shape (batch_size, sequence_length_k, d_v)
mask (torch.tensor): of dtype (bool) or (byte) and shape (batch_size, 1, sequence_length_k), optional
zeroes (or False) indicate positions that cannot contribute to attention
Outputs:
        output (torch.tensor): of shape (batch_size, sequence_length_q, d_v). Each slice output[b, i, :] is a convex combination of the values:
            output[b, i, :] = \sum_j attn[b, i, j] V[b, j, :] with \sum_j attn[b, i, j] = 1.
        attentions (torch.tensor): of shape (batch_size, sequence_length_q, sequence_length_k), where the [b, i, j]-element
            corresponds to the attention weight (i.e., the relative contribution) of position j in the key tensor to position i in the query tensor for element b of the batch.
"""
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
self.softmax = nn.Softmax(dim=-1)
def forward(self, Q, K, V, mask=None):
_, _, d = K.size()
attn = torch.bmm(Q, K.transpose(1, 2)) / d ** 0.5
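        # Mask out disallowed key positions: set their logits to a large negative
        # value before the softmax, then zero their weights afterwards so they
        # contribute nothing to the output.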
if mask is not None:
attn = attn.float().masked_fill(mask == 0, -1e9)
attn = self.softmax(attn)
if mask is not None:
attn = attn.float().masked_fill(mask == 0, 0)
if V.dtype == torch.float16:
attn = attn.half()
output = torch.bmm(attn, V)
return output, attn
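# Illustrative usage sketch (not part of the library API): shapes follow the class
# docstring above; the tensor sizes here are arbitrary.
#
#     attention = ScaledDotProductAttention()
#     Q = torch.randn(2, 5, 16)                      # (batch_size, len_q, d_k)
#     K = torch.randn(2, 7, 16)                      # (batch_size, len_k, d_k)
#     V = torch.randn(2, 7, 32)                      # (batch_size, len_k, d_v)
#     mask = torch.ones(2, 1, 7, dtype=torch.bool)   # all key positions visible
#     output, attn = attention(Q, K, V, mask=mask)   # output: (2, 5, 32), attn: (2, 5, 7)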
class MultiHeadAttention(nn.Module):
"""Multi-head attention with scaled dot product attention. See 'The Annotated Transformer'
http://nlp.seas.harvard.edu/2018/04/03/attention.html or 'The Illustrated Transformer' http://jalammar.github.io/illustrated-transformer/
for details and intuition.
Args:
n_head (int): number of attention heads
d_k (int): dimension of the keys and queries in each attention head
d_v (int): dimension of the values in each attention head
d_model (int): input and output dimension for the layer
dropout (float): dropout rate, default is 0.1
Inputs:
Q (torch.tensor): query tensor of shape ```(batch_size, sequence_length_q, d_model)```
K (torch.tensor): key tensor of shape ```(batch_size, sequence_length_k, d_model)```
V (torch.tensor): value tensor of shape ```(batch_size, sequence_length_k, d_model)```
        mask (torch.tensor): (optional) of dtype ```bool``` or ```byte``` and shape (batch_size, 1, sequence_length_k),
zeroes (or False) indicate positions that cannot contribute to attention
Outputs:
output (torch.tensor) : of shape ```(batch_size, sequence_length_q, d_model)```
        attentions (torch.tensor): of shape ```(batch_size * n_head, sequence_length_q, sequence_length_k)``` where
            ```attentions[batch_size*(i):batch_size*(i+1),:,:]``` corresponds to the batch of attention blocks for the i-th head. See
```chroma.layers.attention.ScaledDotProductAttention``` for more details
"""
def __init__(self, n_head, d_k, d_v, d_model, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.d_model = d_model
self.Wq = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
self.Wk = nn.Parameter(torch.Tensor(n_head, d_model, d_k))
self.Wv = nn.Parameter(torch.Tensor(n_head, d_model, d_v))
self.Wo = nn.Parameter(torch.Tensor(n_head * d_v, d_model))
self.attention = ScaledDotProductAttention()
self.dropout = nn.Dropout(p=dropout)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.Wq)
nn.init.xavier_normal_(self.Wk)
nn.init.xavier_normal_(self.Wv)
nn.init.kaiming_uniform_(self.Wo)
def forward(self, Q, K, V, bias=None, mask=None):
mb_size, len_q, d_q_in = Q.size()
mb_size, len_k, d_k_in = K.size()
mb_size, len_v, d_v_in = V.size()
d_model = self.d_model
if d_q_in != d_model:
raise ValueError("Dimension of Q does not match d_model.")
if d_k_in != d_model:
raise ValueError("Dimension of K does not match d_model.")
if d_v_in != d_model:
raise ValueError("Dimension of V does not match d_model.")
# treat as a (n_head) size batch and project to d_k and d_v
q_s = torch.cat([Q @ W for W in self.Wq]) # (n_head*mb_size) x len_q x d_k
k_s = torch.cat([K @ W for W in self.Wk]) # (n_head*mb_size) x len_k x d_k
v_s = torch.cat([V @ W for W in self.Wv]) # (n_head*mb_size) x len_v x d_v
# Attention
if mask is not None:
mask = mask.repeat(self.n_head, 1, 1)
outputs, attns = self.attention(q_s, k_s, v_s, mask=mask)
# Back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)
# Project back to residual size
outputs = outputs @ self.Wo
outputs = self.dropout(outputs)
return outputs, attns
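# Illustrative usage sketch (not part of the library API): a 4-head layer with
# d_model=64 and 16-dimensional keys/values per head.
#
#     mha = MultiHeadAttention(n_head=4, d_k=16, d_v=16, d_model=64)
#     Q = torch.randn(2, 5, 64)                      # (batch_size, len_q, d_model)
#     K = torch.randn(2, 7, 64)                      # (batch_size, len_k, d_model)
#     V = torch.randn(2, 7, 64)                      # (batch_size, len_k, d_model)
#     mask = torch.ones(2, 1, 7, dtype=torch.bool)
#     out, attns = mha(Q, K, V, mask=mask)           # out: (2, 5, 64), attns: (8, 5, 7)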
class AttentionChainPool(nn.Module):
"""Pools residue-based representations to chain-based representations using a chain mask and attention.
Args:
n_head (int): number of attention heads
d_model (int): dimension of embeddings to be pooled
Inputs:
h (torch.tensor): of size (batch_size, sequence_length, d_model)
C (torch.tensor): of size (batch_size, sequence_length)
Outputs:
output (torch.tensor): of size (batch_size, n_chains, d_model)
chain_mask (torch.tensor): of size (batch_size, n_chains)
"""
def __init__(self, n_head, d_model):
super().__init__()
self.attention = MultiHeadAttention(
n_head, d_model, d_model, d_model, dropout=0.0
)
def get_query(self, x):
return torch.ones(x.size(0), 1, x.size(2)).type(x.dtype).to(x.device)
def forward(self, h, C):
bs, num_res = C.size()
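        # Enumerate the nonzero chain identifiers and tile the batch once per chain,
        # so each (chain, batch element) pair is pooled by a single masked attention query.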
chains = C.abs().unique()
chains = (
chains[chains > 0].unsqueeze(-1).repeat(1, bs).reshape(-1).unsqueeze(-1)
)
num_chains = len(chains.unique())
h_repeat = h.repeat(num_chains, 1, 1)
C_repeat = C.repeat(num_chains, 1)
mask = (C_repeat == chains).unsqueeze(-2)
output, _ = self.attention(
self.get_query(h_repeat), h_repeat, h_repeat, mask=mask
)
output = torch.cat(output.split(bs), 1)
chain_mask = torch.stack(mask.squeeze(1).any(dim=-1).split(bs), -1)
return output, chain_mask
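# Illustrative usage sketch (not part of the library API): pooling per-residue
# embeddings into per-chain embeddings, with chain IDs encoded in C.
#
#     pool = AttentionChainPool(n_head=4, d_model=64)
#     h = torch.randn(2, 10, 64)                     # (batch_size, sequence_length, d_model)
#     C = torch.ones(2, 10).long()                   # chain ID 1 for all residues
#     C[:, 5:] = 2                                   # chain ID 2 for the last 5 residues
#     pooled, chain_mask = pool(h, C)                # pooled: (2, 2, 64), chain_mask: (2, 2)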
class Attention(nn.Module):
"""
A multi-head attention layer with optional gating and bias as implemented in Jumper et al. (2021)
Args:
n_head (int): Number of heads of attention
d_model (int): Dimension of input and outputs
d_k (int): Dimension of keys/queries
d_v (int): Dimension of values
gate (bool): Whether to include a gate connection (as in Jumper et al. (2021))
Inputs:
Q (torch.tensor): of size (batch_size, num_queries, d_model)
K (torch.tensor): of size (batch_size, num_keys, d_model)
V (torch.tensor): of size (batch_size, num_keys, d_model)
bias (torch.tensor): (optional) of size (batch_size, n_head, num_queries, num_keys)
mask (torch.tensor): (optional) of size (batch_size, n_head, num_queries, num_keys)
Outputs:
output (torch.tensor): of size (batch_size, num_queries, d_model)
"""
def __init__(self, n_head, d_model, d_k=None, d_v=None, gate=False):
super().__init__()
self.n_head = n_head
self.d_model = d_model
self.d_k = d_model // n_head if d_k is None else d_k
self.d_v = d_model // n_head if d_v is None else d_v
self.gate = gate
self.q_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
self.k_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_k))
self.v_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
self.o_weights = nn.Parameter(torch.Tensor(n_head, self.d_v, d_model))
self.o_bias = nn.Parameter(torch.Tensor(d_model))
if self.gate:
self.g_weights = nn.Parameter(torch.Tensor(d_model, n_head, self.d_v))
self.g_bias = nn.Parameter(torch.Tensor(n_head, self.d_v))
self.softmax = nn.Softmax(dim=-1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_weights)
nn.init.xavier_uniform_(self.k_weights)
nn.init.xavier_uniform_(self.v_weights)
nn.init.xavier_uniform_(self.o_weights)
nn.init.zeros_(self.o_bias)
if self.gate:
nn.init.zeros_(self.g_weights)
nn.init.ones_(self.g_bias)
def forward(self, Q, K, V, bias=None, mask=None):
self._check_inputs(Q, K, V, bias, mask)
q = torch.einsum("bqa,ahc->bqhc", Q, self.q_weights) * self.d_k ** (-0.5)
k = torch.einsum("bka,ahc->bkhc", K, self.k_weights)
v = torch.einsum("bka,ahc->bkhc", V, self.v_weights)
logits = torch.einsum("bqhc,bkhc->bhqk", q, k)
if bias is not None:
logits = logits + bias
        weights = self.softmax(logits)
if mask is not None:
weights = weights.masked_fill(~mask, 0.0)
weighted_avg = torch.einsum("bhqk,bkhc->bqhc", weights, v)
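        # Optional per-head output gating (Jumper et al. 2021): the attended values
        # are scaled by a sigmoid gate computed from the query inputs.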
if self.gate:
gate_values = torch.einsum("bqa,ahc->bqhc", Q, self.g_weights) + self.g_bias
            gate_values = torch.sigmoid(gate_values)
weighted_avg = weighted_avg * gate_values
output = (
torch.einsum("bqhc,hco->bqo", weighted_avg, self.o_weights) + self.o_bias
)
return output
def _check_inputs(self, Q, K, V, bias, mask):
batch_size_q, num_queries, d_q_in = Q.size()
batch_size_k, num_keys, d_k_in = K.size()
batch_size_v, num_values, d_v_in = V.size()
if d_q_in != self.d_model:
raise ValueError(
f"Dimension of Q tensor needs to be (batch_size, number_queries, d_model)"
)
if d_k_in != self.d_model:
raise ValueError(
f"Dimension of K tensor needs to be (batch_size, number_keys, d_model)"
)
if d_v_in != self.d_model:
raise ValueError(
f"Dimension of V tensor needs to be (batch_size, number_values, d_model)"
)
if num_keys != num_values:
raise ValueError(f"Number of keys needs to match number of values passed")
if (batch_size_q != batch_size_k) or (batch_size_k != batch_size_v):
raise ValueError(
f"Found batch size mismatch among inputs, all tensors must agree in size of dimension 0"
)
if bias is not None:
if (bias.dim() != 3) and (bias.dim() != 4):
raise ValueError(
f"Bias specified but dimension mismatched: passed {bias.dim()}-dimensional tensor but should be 3-dimensional"
f"of shape (n_head, num_queries, num_keys) or 4-dimensional of shape (batch_size, n_head, num_queries, num_keys)"
)
if bias.dim() == 3:
n_head_b, num_queries_b, num_keys_b = bias.size()
if n_head_b != self.n_head:
raise ValueError(
f"Bias specified but number of heads (dim of axis=0) does not match number of heads: {self.n_head}"
)
if num_queries_b != num_queries:
raise ValueError(
f"Bias specified but number of queries (dim of axis=1) does not match number of queries given in Q tensor"
)
if num_keys_b != num_keys:
raise ValueError(
f"Bias specified but number of keys (dim of axis=2) does not match number of queries given in K tensor "
f"(dimenson of axis=1)"
)
            elif bias.dim() == 4:
                n_batch_b, n_head_b, num_queries_b, num_keys_b = bias.size()
                if n_head_b != self.n_head:
                    raise ValueError(
                        f"Bias specified but the number of heads (dim of axis=1) does not match the layer's number of heads: {self.n_head}"
                    )
                if num_queries_b != num_queries:
                    raise ValueError(
                        "Bias specified but the number of queries (dim of axis=2) does not match the number of queries given in the Q tensor"
                    )
                if num_keys_b != num_keys:
                    raise ValueError(
                        "Bias specified but the number of keys (dim of axis=3) does not match the number of keys given in the K tensor "
                        "(dimension of axis=1)"
                    )
if mask is not None:
if mask.dtype != torch.bool:
raise ValueError(
f"Mask specified but not given by correct dtype, should be torch.bool but found {mask.dtype}"
)
if mask.dim() != 4:
raise ValueError(
f"Mask specified but dimension mismatched: passed {mask.dim()}-dimensional tensor but should be 4-dimensional"
f"of shape (batch_size, n_head, num_queries, num_keys)"
)
batch_size_b, _, num_queries_b, num_keys_b = mask.size()
if (num_queries_b != num_queries) and (num_queries_b != 1):
                raise ValueError(
                    "Mask specified but the number of queries (dim of axis=2) does not match the number of queries given in the Q tensor"
                )
            if (num_keys_b != num_keys) and (num_keys_b != 1):
                raise ValueError(
                    "Mask specified but the number of keys (dim of axis=3) does not match the number of keys given in the K tensor "
                    "(dimension of axis=1)"
                )