# M3Site: esm/layers/attention.py

import functools

import einops
import torch
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import RotaryEmbedding


class MultiHeadAttention(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
bias: bool = False,
qk_layernorm: bool = True,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = self.d_model // self.n_heads
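        # Fused pre-LayerNorm + linear projection producing Q, K and V in one matmul.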
self.layernorm_qkv = nn.Sequential(
nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
)
self.out_proj = nn.Linear(d_model, d_model, bias=bias)
if qk_layernorm:
self.q_ln = nn.LayerNorm(d_model, bias=bias)
self.k_ln = nn.LayerNorm(d_model, bias=bias)
else:
self.q_ln = nn.Identity()
self.k_ln = nn.Identity()
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
q = q.unflatten(-1, (self.n_heads, self.d_head))
k = k.unflatten(-1, (self.n_heads, self.d_head))
q, k = self.rotary(q, k)
q = q.flatten(-2, -1)
k = k.flatten(-2, -1)
return q, k
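
    # Shape-suffix convention used below: B = batch, L = sequence length,
    # H = n_heads; D is d_model in *_BLD tensors and d_head in *_BHLD tensors.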
    def forward(self, x: torch.Tensor, seq_id: torch.Tensor) -> torch.Tensor:
qkv_BLD3 = self.layernorm_qkv(x)
query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD)
query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
n_heads = self.n_heads
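        # Split the hidden dim into heads: (B, L, H * d_head) -> (B, H, L, d_head).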
reshaper = functools.partial(
einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
)
query_BHLD, key_BHLD, value_BHLD = map(
reshaper, (query_BLD, key_BLD, value_BLD)
)
        # True where two positions share a seq_id, i.e. where attention is allowed.
mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask_BHLL = mask_BLL.unsqueeze(1)
context_BHLD = F.scaled_dot_product_attention(
query_BHLD, key_BHLD, value_BHLD, mask_BHLL
)
context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
return self.out_proj(context_BLD)
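

# Minimal usage sketch (illustrative only: the sizes and the seq_id layout
# below are assumptions, not values taken from the surrounding repo).
if __name__ == "__main__":
    attn = MultiHeadAttention(d_model=64, n_heads=4)
    x = torch.randn(2, 8, 64)  # (batch, length, d_model)
    # seq_id labels which packed sequence each token belongs to; tokens only
    # attend to tokens with the same label (block-diagonal attention mask).
    seq_id = torch.tensor(
        [
            [0, 0, 0, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 0, 0, 0, 0],
        ]
    )
    out = attn(x, seq_id)
    print(out.shape)  # torch.Size([2, 8, 64])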