import functools

import einops
import torch
import torch.nn.functional as F
from torch import nn

from esm.layers.rotary import RotaryEmbedding

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        bias: bool = False,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        # Pre-attention LayerNorm fused with the joint QKV projection.
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        # Optional LayerNorm on queries and keys (applied before rotary);
        # Identity when disabled.
        if qk_layernorm:
            self.q_ln = nn.LayerNorm(d_model, bias=bias)
            self.k_ln = nn.LayerNorm(d_model, bias=bias)
        else:
            self.q_ln = nn.Identity()
            self.k_ln = nn.Identity()
        # Rotary positional embedding over each head's dimension.
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
        # Split the flat hidden dim into heads, rotate, and flatten back:
        # (B, L, D) -> (B, L, H, D/H) -> rotary -> (B, L, D).
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x: torch.Tensor, seq_id: torch.Tensor) -> torch.Tensor:
        # Tensor-name suffixes spell out dims: B = batch, L = length,
        # H = heads, D = hidden (per-head after the BHLD reshape).
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD)
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)

        n_heads = self.n_heads
        reshaper = functools.partial(
            einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
        )
        query_BHLD, key_BHLD, value_BHLD = map(
            reshaper, (query_BLD, key_BLD, value_BLD)
        )

        # Where True, enable participation in attention: a token may only
        # attend to tokens with the same sequence id, so sequences packed
        # into one batch row stay independent (block-diagonal boolean mask).
        mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
        mask_BHLL = mask_BLL.unsqueeze(1)  # broadcast the mask over heads

        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD, mask_BHLL
        )
        context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
        return self.out_proj(context_BLD)
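
# --- Usage sketch (illustrative addition, not part of the esm source) ---
# The shapes and seq_id values below are hypothetical. seq_id labels each
# token with the sequence it belongs to, so two sequences packed into one
# batch row never attend to each other.
if __name__ == "__main__":
    attn = MultiHeadAttention(d_model=64, n_heads=4)
    x = torch.randn(1, 6, 64)  # (B=1, L=6, D=64)
    # Tokens 0-2 form one packed sequence, tokens 3-5 another.
    seq_id = torch.tensor([[0, 0, 0, 1, 1, 1]])
    out = attn(x, seq_id)
    print(out.shape)  # torch.Size([1, 6, 64])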