# hma/genie/attention.py
import torch
from torch import nn
from xformers.ops import LowerTriangularMask, memory_efficient_attention, unbind
import os
XFORMERS_DISABLED = os.environ.get("XFORMERS_DISABLED", "false").lower() == "true"
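# Set XFORMERS_DISABLED=true in the environment to fall back to the pure-PyTorch
# attention implementation (see the SelfAttention alias at the bottom of this file).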
class BasicSelfAttention(nn.Module):
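    """Multi-head self-attention in plain PyTorch.

    Supports optional QK-normalization (LayerNorm over each head's query/key
    vectors) and a muP-style fixed scale of 8 / head_dim in place of the usual
    head_dim ** -0.5.
    """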
def __init__(
self,
num_heads: int,
d_model: int,
qkv_bias: bool = False,
proj_bias: bool = True,
qk_norm: bool = True,
use_mup: bool = True,
attn_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = d_model // num_heads
        # muP-style attention scaling: 8 / head_dim equals the standard
        # head_dim ** -0.5 when head_dim == 64.
        self.scale = 8 / self.head_dim if use_mup else self.head_dim**-0.5
self.qkv = nn.Linear(d_model, d_model * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model, bias=proj_bias)
self.qk_norm = qk_norm
if self.qk_norm:
            # QK-normalization, https://arxiv.org/pdf/2302.05442
            # LayerNorm runs in fp32, so q and k have to be cast back to the
            # value dtype after normalization (see forward()).
self.norm = nn.LayerNorm(self.head_dim, eps=1e-05)
def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
B, N, C = x.shape
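        # Project to q, k, v and split heads: qkv has shape (3, B, num_heads, N, head_dim).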
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
if self.qk_norm:
q = self.norm(q)
k = self.norm(k)
            # LayerNorm ran in float32; cast q/k back to the value dtype (e.g. bf16).
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
q *= self.scale
attn = q @ k.transpose(-2, -1)
if causal:
mask_value = -torch.finfo(attn.dtype).max
i, j = attn.shape[-2:]
            mask = ~torch.tril(torch.ones(i, j, dtype=torch.bool, device=attn.device))
attn = attn.masked_fill(mask, mask_value)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class BasicCrossAttention(nn.Module):
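    """Multi-head cross-attention in plain PyTorch.

    Keys/values may arrive with a smaller batch than the queries (queries are
    (b*s, t, c) while keys/values are (b, t, c)); they are repeated along the
    batch dimension so each group of queries attends to the same context.
    """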
def __init__(
self,
num_heads: int,
d_model: int,
k_model: int,
qkv_bias: bool = False,
proj_bias: bool = True,
qk_norm: bool = True,
use_mup: bool = True,
attn_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = d_model // num_heads
        # muP-style attention scaling: 8 / head_dim equals the standard
        # head_dim ** -0.5 when head_dim == 64.
        self.scale = 8 / self.head_dim if use_mup else self.head_dim**-0.5
self.to_q = nn.Linear(d_model, d_model, bias=qkv_bias)
self.to_k = nn.Linear(d_model, d_model, bias=qkv_bias)
self.to_v = nn.Linear(d_model, d_model, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model, bias=proj_bias)
self.qk_norm = qk_norm
if self.qk_norm:
            # QK-normalization, https://arxiv.org/pdf/2302.05442
            # LayerNorm runs in fp32, so q and k have to be cast back to the
            # value dtype after normalization (see forward()).
self.norm = nn.LayerNorm(self.head_dim, eps=1e-05)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False) -> torch.Tensor:
"""
q: (b s) t c
k: (b) t c
"""
B, N, C = q.shape
        # Broadcast keys/values to the query batch and truncate them to the query length.
        k = k.repeat(B // k.shape[0], 1, 1)
        v = v.repeat(B // v.shape[0], 1, 1)
        k = k[:, :q.shape[1]]
        v = v[:, :q.shape[1]]
B, M, _ = k.shape
        # Project and split heads: (B, seq, num_heads, head_dim) -> (B, num_heads, seq, head_dim).
        q = self.to_q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.to_k(k).reshape(B, M, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.to_v(v).reshape(B, M, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
if self.qk_norm:
q = self.norm(q)
k = self.norm(k)
            # LayerNorm ran in float32; cast q/k back to the value dtype (e.g. bf16).
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
q *= self.scale
attn = q @ k.transpose(-2, -1)
if causal:
mask_value = -torch.finfo(attn.dtype).max
i, j = attn.shape[-2:]
            mask = ~torch.tril(torch.ones(i, j, dtype=torch.bool, device=attn.device))
attn = attn.masked_fill(mask, mask_value)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class MemoryEfficientAttention(BasicSelfAttention):
# NOTE: Mem-eff attention from xformers is actually Flash Attention 2
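    # xformers' memory_efficient_attention expects inputs shaped
    # (B, seq_len, num_heads, head_dim), so heads stay in dim 2 here (no permute),
    # and LowerTriangularMask applies causal masking without materializing a mask.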
def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
q, k, v = unbind(qkv, 2)
if self.qk_norm:
q = self.norm(q)
k = self.norm(k)
            # LayerNorm ran in float32; cast q/k back to the value dtype (e.g. bf16).
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
attn_bias = LowerTriangularMask() if causal else None
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, scale=self.scale)
x = x.reshape([B, N, C])
x = self.proj(x)
return x
if XFORMERS_DISABLED:
SelfAttention = BasicSelfAttention
else:
SelfAttention = MemoryEfficientAttention
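

if __name__ == "__main__":
    # Minimal usage sketch / smoke test (illustrative only): the batch size,
    # sequence length, and model width below are arbitrary assumptions, and the
    # pure-PyTorch classes are used so xformers is not exercised.
    torch.manual_seed(0)

    self_attn = BasicSelfAttention(num_heads=8, d_model=512)
    x = torch.randn(2, 16, 512)                 # (batch, tokens, channels)
    print(self_attn(x, causal=True).shape)      # torch.Size([2, 16, 512])

    cross_attn = BasicCrossAttention(num_heads=8, d_model=512, k_model=512)
    q = torch.randn(4, 16, 512)                 # (b*s, t, c) queries
    kv = torch.randn(2, 16, 512)                # (b, t, c) shared context
    print(cross_attn(q, kv, kv).shape)          # torch.Size([4, 16, 512])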