import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import einops
from einops import rearrange, repeat
from inspect import isfunction

from .rotary import RotaryEmbedding
from .modules import RMSNorm

# Prefer the fused scaled-dot-product attention kernel when available
# (torch >= 2.0); otherwise fall back to the explicit matmul/softmax path.
if hasattr(nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')


def add_mask(sim, mask):
    """Mask out attention scores with the most negative finite value.

    Args:
        sim: attention scores of shape (B, H, N, M).
        mask: boolean mask, True = "may attend". Accepted shapes:
            (N, M), (B, N, M), or an already 4-D broadcastable mask.

    Returns:
        `sim` with masked positions filled so softmax drives them to ~0.
    """
    b, ndim = sim.shape[0], mask.ndim
    if ndim == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    if ndim == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=b)
    max_neg_value = -torch.finfo(sim.dtype).max
    sim = sim.masked_fill(~mask, max_neg_value)
    return sim


def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
    """Build a broadcastable (B, 1, i, j) boolean attention mask.

    Missing per-token masks default to all-True. The result is the outer
    product of the query mask (B, i) and the key mask (B, j).
    """
    def default(val, d):
        return val if val is not None else (d() if isfunction(d) else d)

    b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
    q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
    k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
    attn_mask = (rearrange(q_mask, 'b i -> b 1 i 1')
                 * rearrange(k_mask, 'b j -> b 1 1 j'))
    return attn_mask


class Attention(nn.Module):
    """Multi-head self-/cross-attention with optional QK-norm and RoPE.

    When `context_dim` is given the module performs cross-attention
    (queries from `x`, keys/values from `context`); rotary embeddings are
    then disallowed. `extras` in `forward` counts leading conditioning
    tokens that are exempt from (or get their own) rotary embedding.
    """

    def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False,
                 qk_scale=None, qk_norm=None, attn_drop=0., proj_drop=0.,
                 rope_mode='none'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Cross-attention iff a context dimension is supplied.
        self.cross_attn = context_dim is not None
        context_dim = dim if context_dim is None else context_dim

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)

        # QK normalization operates per-head (last dim = head_dim).
        if qk_norm is None:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()
        elif qk_norm == 'layernorm':
            self.norm_q = nn.LayerNorm(head_dim)
            self.norm_k = nn.LayerNorm(head_dim)
        elif qk_norm == 'rmsnorm':
            self.norm_q = RMSNorm(head_dim)
            self.norm_k = RMSNorm(head_dim)
        else:
            raise NotImplementedError

        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if self.cross_attn:
            assert rope_mode == 'none'
        self.rope_mode = rope_mode
        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
            self.rotary = RotaryEmbedding(dim=head_dim)
        elif self.rope_mode == 'dual':
            self.rotary_x = RotaryEmbedding(dim=head_dim)
            self.rotary_c = RotaryEmbedding(dim=head_dim)

    def _rotary(self, q, k, extras):
        """Apply rotary embeddings to (B, H, L, D) q/k per `rope_mode`.

        The first `extras` tokens along the sequence axis are conditioning
        tokens: 'x_only' leaves them unrotated, 'dual' rotates them with a
        separate embedding, 'shared' rotates everything together.
        """
        if self.rope_mode == 'shared':
            q, k = self.rotary(q=q, k=k)
        elif self.rope_mode == 'x_only':
            q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'dual':
            q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
            q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'none':
            pass
        else:
            raise NotImplementedError
        return q, k

    def _attn(self, q, k, v, mask_binary):
        """Attention core: (B, H, L, D) q/k/v -> (B, L, H*D) output."""
        if ATTENTION_MODE == 'flash':
            # BUG FIX: SDPA applies dropout_p unconditionally, so it must
            # be zeroed in eval mode to match nn.Dropout semantics.
            # NOTE(review): SDPA uses its default 1/sqrt(D) scaling, so a
            # custom qk_scale is only honored on the 'math' path — confirm
            # whether callers rely on non-default qk_scale.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop_p if self.training else 0.,
                attn_mask=mask_binary)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # BUG FIX: the original did (attn @ v).transpose(1, 2) before
            # the rearrange, but 'B H L D -> B L (H D)' already swaps H and
            # L, so the combination mislabeled axes and produced a scrambled
            # (B, H, L*D) tensor. The rearrange alone matches the flash path.
            x = attn @ v
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        else:
            raise NotImplementedError
        return x

    def forward(self, x, context=None, context_mask=None, extras=0):
        """Attend from `x` (B, L, C) to `context` (defaults to `x`).

        Args:
            x: query tokens, (B, L, C).
            context: key/value tokens for cross-attention; None = self-attn.
            context_mask: optional (B, Lc) boolean key padding mask.
            extras: number of leading conditioning tokens for RoPE handling.

        Returns:
            (B, L, C) attended and projected output.
        """
        B, L, C = x.shape
        if context is None:
            context = x

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        if context_mask is not None:
            mask_binary = create_mask(x.shape, context.shape,
                                      x.device, None, context_mask)
        else:
            mask_binary = None

        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)

        q = self.norm_q(q)
        k = self.norm_k(k)

        q, k = self._rotary(q, k, extras)

        x = self._attn(q, k, v, mask_binary)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class JointAttention(nn.Module):
    """Joint attention over concatenated [context, x] token streams.

    Each stream gets its own QKV projections and QK-norms; the per-stream
    q/k/v are concatenated along the sequence axis (context first), a
    single attention pass is run, and the output is split back and
    projected per stream. Returns `(x_out, context_out)`.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 qk_norm=None, attn_drop=0., proj_drop=0., rope_mode='none'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias)
        self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias)

        self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
        self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)

        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_drop_x = nn.Dropout(proj_drop)
        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop_c = nn.Dropout(proj_drop)

        self.rope_mode = rope_mode
        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
            self.rotary = RotaryEmbedding(dim=head_dim)
        elif self.rope_mode == 'dual':
            self.rotary_x = RotaryEmbedding(dim=head_dim)
            self.rotary_c = RotaryEmbedding(dim=head_dim)

    def _make_qkv_layers(self, dim, qkv_bias):
        """Return a (q, k, v) triple of dim->dim linear projections."""
        return (nn.Linear(dim, dim, bias=qkv_bias),
                nn.Linear(dim, dim, bias=qkv_bias),
                nn.Linear(dim, dim, bias=qkv_bias))

    def _make_norm_layers(self, qk_norm, head_dim):
        """Return (norm_q, norm_k) per-head normalization layers."""
        if qk_norm is None:
            norm_q = nn.Identity()
            norm_k = nn.Identity()
        elif qk_norm == 'layernorm':
            norm_q = nn.LayerNorm(head_dim)
            norm_k = nn.LayerNorm(head_dim)
        elif qk_norm == 'rmsnorm':
            norm_q = RMSNorm(head_dim)
            norm_k = RMSNorm(head_dim)
        else:
            raise NotImplementedError
        return norm_q, norm_k

    def _rotary(self, q, k, extras):
        """Apply rotary embeddings per `rope_mode`; see Attention._rotary.

        Here `extras` is the number of leading context tokens in the joint
        sequence (context precedes x after concatenation).
        """
        if self.rope_mode == 'shared':
            q, k = self.rotary(q=q, k=k)
        elif self.rope_mode == 'x_only':
            q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'dual':
            q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
            q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'none':
            pass
        else:
            raise NotImplementedError
        return q, k

    def _attn(self, q, k, v, mask_binary):
        """Attention core: (B, H, L, D) q/k/v -> (B, L, H*D) output."""
        if ATTENTION_MODE == 'flash':
            # BUG FIX: dropout must be disabled at eval time; SDPA applies
            # `dropout_p` unconditionally.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop_p if self.training else 0.,
                attn_mask=mask_binary)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # BUG FIX: dropped the spurious .transpose(1, 2) that, combined
            # with the rearrange below, scrambled the output axes (the
            # rearrange alone matches the flash path).
            x = attn @ v
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        else:
            raise NotImplementedError
        return x

    def _cat_mask(self, x, context, x_mask=None, context_mask=None):
        """Concatenate per-stream key masks as [context, x] (joint order)."""
        B = x.shape[0]
        if x_mask is None:
            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
        if context_mask is None:
            context_mask = torch.ones(B, context.shape[-2],
                                      device=context.device).bool()
        mask = torch.cat([context_mask, x_mask], dim=1)
        return mask

    def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
        """Jointly attend over [context, x] and return per-stream outputs.

        Args:
            x: (B, Lx, C) tokens.
            context: (B, Lc, C) tokens, placed first in the joint sequence.
            x_mask / context_mask: optional (B, L*) boolean padding masks.
            extras: leading-token count for RoPE handling.

        Returns:
            Tuple (x_out, context_out) of shapes (B, Lx, C) and (B, Lc, C).
        """
        B, Lx, C = x.shape
        _, Lc, _ = context.shape

        if x_mask is not None or context_mask is not None:
            mask = self._cat_mask(x, context,
                                  x_mask=x_mask, context_mask=context_mask)
            shape = [B, Lx + Lc, C]
            mask_binary = create_mask(q_shape=shape, k_shape=shape,
                                      device=x.device,
                                      q_mask=None, k_mask=mask)
        else:
            mask_binary = None

        qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
        qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context)

        qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
                                                    H=self.num_heads),
                         [qx, kx, vx])
        qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D',
                                                    H=self.num_heads),
                         [qc, kc, vc])

        qx, kx = self.norm_qx(qx), self.norm_kx(kx)
        qc, kc = self.norm_qc(qc), self.norm_kc(kc)

        # Context tokens lead the joint sequence (matches _cat_mask order).
        q, k, v = (torch.cat([qc, qx], dim=2),
                   torch.cat([kc, kx], dim=2),
                   torch.cat([vc, vx], dim=2))

        q, k = self._rotary(q, k, extras)

        x = self._attn(q, k, v, mask_binary)

        context, x = x[:, :Lc, :], x[:, Lc:, :]

        x = self.proj_x(x)
        x = self.proj_drop_x(x)

        context = self.proj_c(context)
        context = self.proj_drop_c(context)
        return x, context