import os

import torch
from torch import nn

XFORMERS_DISABLED = os.environ.get("XFORMERS_DISABLED", "false").lower() == "true"

# Only import xformers when it is not explicitly disabled, so the basic
# attention path still works in environments without xformers installed.
if not XFORMERS_DISABLED:
    from xformers.ops import LowerTriangularMask, memory_efficient_attention, unbind


class BasicSelfAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        d_model: int,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        qk_norm: bool = True,
        use_mup: bool = True,
        attn_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # muP-style 1/d attention scaling; the factor of 8 makes it equal to
        # the standard 1/sqrt(d) when head_dim=64
        self.scale = 8 / self.head_dim if use_mup else self.head_dim**-0.5

        self.qkv = nn.Linear(d_model, d_model * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model, bias=proj_bias)

        self.qk_norm = qk_norm
        if self.qk_norm:
            # qk normalization, https://arxiv.org/pdf/2302.05442
            # Note that LN is done in fp32, so q and k have to be cast back to
            # the compute dtype afterwards (see forward)
            self.norm = nn.LayerNorm(self.head_dim, eps=1e-05)

    def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, C) -> (B, N, 3, H, D) -> (3, B, H, N, D)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.qk_norm:
            q = self.norm(q)
            k = self.norm(k)
            # LN done in float32, cast back to bf16
            q = q.to(dtype=v.dtype)
            k = k.to(dtype=v.dtype)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if causal:
            mask_value = -torch.finfo(attn.dtype).max
            i, j = attn.shape[-2:]
            mask = ~torch.tril(torch.ones(i, j, device=attn.device)).bool()
            attn = attn.masked_fill(mask, mask_value)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class BasicCrossAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        d_model: int,
        k_model: int,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        qk_norm: bool = True,
        use_mup: bool = True,
        attn_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # muP-style 1/d attention scaling; the factor of 8 makes it equal to
        # the standard 1/sqrt(d) when head_dim=64
        self.scale = 8 / self.head_dim if use_mup else self.head_dim**-0.5

        self.to_q = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.to_k = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.to_v = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model, bias=proj_bias)

        self.qk_norm = qk_norm
        if self.qk_norm:
            # qk normalization, https://arxiv.org/pdf/2302.05442
            # Note that LN is done in fp32, so q and k have to be cast back to
            # the compute dtype afterwards (see forward)
            self.norm = nn.LayerNorm(self.head_dim, eps=1e-05)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False) -> torch.Tensor:
        """
        q: (b s) t c
        k, v: (b) t c
        """
        B, N, C = q.shape
        # Repeat keys/values along the batch dim to match the query batch,
        # then truncate their sequence length to the query length
        k = k.repeat(B // len(k), 1, 1)
        v = v.repeat(B // len(v), 1, 1)
        k = k[:, : q.shape[1]]
        v = v[:, : q.shape[1]]
        B, M, _ = k.shape

        # Project and split heads: (B, L, C) -> (B, H, L, D)
        q = self.to_q(q).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(k).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(v).reshape(B, M, self.num_heads, self.head_dim).transpose(1, 2)

        if self.qk_norm:
            q = self.norm(q)
            k = self.norm(k)
            # LN done in float32, cast back to bf16
            q = q.to(dtype=v.dtype)
            k = k.to(dtype=v.dtype)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if causal:
            mask_value = -torch.finfo(attn.dtype).max
            i, j = attn.shape[-2:]
            mask = ~torch.tril(torch.ones(i, j, device=attn.device)).bool()
            attn = attn.masked_fill(mask, mask_value)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class MemoryEfficientAttention(BasicSelfAttention):
    # NOTE: memory_efficient_attention from xformers dispatches to fused kernels
    # (Flash Attention 2 when available)
    def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
        B, N, C = x.shape
        # xformers expects (B, N, H, D) tensors, so no head transpose here
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = unbind(qkv, 2)

        if self.qk_norm:
            q = self.norm(q)
            k = self.norm(k)
            # LN done in float32, cast back to bf16
            q = q.to(dtype=v.dtype)
            k = k.to(dtype=v.dtype)

        attn_bias = LowerTriangularMask() if causal else None
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, scale=self.scale)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        return x


if XFORMERS_DISABLED:
    SelfAttention = BasicSelfAttention
else:
    SelfAttention = MemoryEfficientAttention
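

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API). The
# hyperparameters below (d_model=256, num_heads=4, batch/sequence sizes) are
# assumed values for demonstration, not taken from the surrounding codebase.
# The basic classes are used directly so the example also runs on CPU; the
# SelfAttention alias picks the xformers kernel when it is not disabled.
if __name__ == "__main__":
    torch.manual_seed(0)
    d_model, num_heads = 256, 4

    # Causal self-attention over a (batch=2, tokens=16, channels=256) input
    self_attn = BasicSelfAttention(num_heads=num_heads, d_model=d_model)
    x = torch.randn(2, 16, d_model)
    print(self_attn(x, causal=True).shape)  # torch.Size([2, 16, 256])

    # Cross-attention: queries use a repeated batch "(b s) t c" while
    # keys/values stay at "(b) t c", matching the docstring convention above
    cross_attn = BasicCrossAttention(num_heads=num_heads, d_model=d_model, k_model=d_model)
    q = torch.randn(4, 16, d_model)   # b=2, s=2 -> query batch of 4
    kv = torch.randn(2, 16, d_model)  # shared across the repeated batch
    print(cross_attn(q, kv, kv).shape)  # torch.Size([4, 16, 256])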