|
from typing import *
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from .full_attn import scaled_dot_product_attention
|
|
|
|
|
|
class MultiHeadRMSNorm(nn.Module):
    """RMS normalisation over the last axis with a separate gain per head.

    The learnable gain ``gamma`` has shape ``(heads, dim)`` and is broadcast
    against the two trailing axes of the input, so inputs are expected to end
    in ``[..., heads, dim]``.
    """

    def __init__(self, dim: int, heads: int):
        """
        Args:
            dim: Size of the normalised (last) axis.
            heads: Number of heads, i.e. number of independent gain vectors.
        """
        super().__init__()
        # A unit-L2 vector times sqrt(dim) has unit root-mean-square.
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last axis and apply the per-head gain.

        The arithmetic runs in float32; the result is cast back to the dtype
        of ``x``.
        """
        input_dtype = x.dtype
        unit = F.normalize(x.float(), dim=-1)
        scaled = unit * self.gamma * self.scale
        return scaled.to(input_dtype)
|
|
|
|
|
|
class RotaryPositionEmbedder(nn.Module):
    """Rotary position embedding (RoPE) for query/key tensors.

    Consecutive feature pairs of ``q`` and ``k`` are treated as complex
    numbers and rotated by an angle ``position * frequency``. Each position
    channel gets ``freq_dim = hidden_size // in_channels // 2`` frequencies;
    any feature pairs left over are not rotated.
    """

    def __init__(self, hidden_size: int, in_channels: int = 3):
        """
        Args:
            hidden_size: Size of the last axis of the tensors to rotate
                (must be even).
            in_channels: Number of position channels used to split the
                frequency budget.
        """
        super().__init__()
        assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.freq_dim = hidden_size // in_channels // 2
        # Geometric frequency ladder 10000 ** (-i / freq_dim). Kept as a plain
        # attribute (not a parameter/buffer) and moved to the right device
        # lazily in `_get_phases`.
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** self.freqs)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        """Return unit complex rotations of shape ``[len(indices), freq_dim]``.

        Args:
            indices: 1-D tensor of scalar positions.
        """
        self.freqs = self.freqs.to(indices.device)
        phases = torch.outer(indices, self.freqs)
        phases = torch.polar(torch.ones_like(phases), phases)
        return phases

    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        """Rotate consecutive feature pairs of ``x`` by ``phases``.

        The multiplication is done in float32 complex arithmetic and the
        result is cast back to the dtype of ``x``.
        """
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * phases
        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
        return x_embed

    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to queries and keys.

        Args:
            q (torch.Tensor): [..., N, D] tensor of queries
            k (torch.Tensor): [..., N, D] tensor of keys
            indices (torch.Tensor): [..., N, C] tensor of spatial positions.
                When omitted, tokens are numbered ``0..N-1`` along axis -2
                (a single position channel).

        Returns:
            The rotated ``(q, k)`` pair, with the shapes and dtypes of the
            inputs.
        """
        if indices is None:
            # Keep an explicit channel axis ([N, 1]) so the reshape below keeps
            # tokens and frequencies on separate axes. The resulting [N, D/2]
            # phases broadcast over any leading axes of q and k.
            indices = torch.arange(q.shape[-2], device=q.device).unsqueeze(-1)

        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)

        # Pad on the feature axis (the LAST axis, whatever the rank of
        # `indices`) with identity rotations so every complex pair of q/k has
        # a matching phase.
        missing = self.hidden_size // 2 - phases.shape[-1]
        if missing > 0:
            identity = torch.ones(
                *phases.shape[:-1], missing, dtype=phases.dtype, device=phases.device
            )
            phases = torch.cat([phases, identity], dim=-1)

        q_embed = self._rotary_embedding(q, phases)
        k_embed = self._rotary_embedding(k, phases)
        return q_embed, k_embed
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Dense multi-head attention layer (self- or cross-attention).

    Inputs are ``[B, L, channels]`` tensors. Queries/keys/values are produced
    by linear projections, split into ``num_heads`` heads, optionally rotated
    with a rotary position embedding (self-attention only) and/or RMS
    normalised per head, passed to ``scaled_dot_product_attention`` and
    projected back to ``channels``.

    NOTE(review): ``scaled_dot_product_attention`` is a project helper; it is
    assumed to accept either a packed qkv/kv tensor or separate q, k, v in
    ``[B, L, H, D]`` layout and to return ``[B, L, H, D]`` -- confirm against
    ``.full_attn``.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int]=None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        """
        Args:
            channels: Feature size of the input and output tokens.
            num_heads: Number of attention heads; must divide ``channels``.
            ctx_channels: Feature size of the cross-attention context
                (defaults to ``channels``).
            type: ``"self"`` or ``"cross"``.
            attn_mode: ``"full"`` or ``"windowed"``; windowed attention is
                not implemented and raises ``NotImplementedError``.
            window_size: Stored only; reserved for windowed attention.
            shift_window: Stored only; reserved for windowed attention.
            qkv_bias: Whether the q/k/v projections have a bias term.
            use_rope: Apply rotary position embedding to q and k
                (used by the self-attention path only).
            qk_rms_norm: Apply per-head RMS normalisation to q and k.
        """
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"

        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            # The rotation is applied to per-head vectors of size head_dim, so
            # the embedder must be sized by head_dim: a `channels`-sized
            # embedder yields channels // 2 phases, which cannot be multiplied
            # with the head_dim // 2 complex pairs of a head.
            self.rope = RotaryPositionEmbedder(self.head_dim)

    def _apply_rope(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        indices: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate per-head queries/keys of shape ``[B, L, H, D]`` by token position.

        ``indices`` may be ``[L, C]`` or ``[B, L, C]``; ``None`` numbers the
        tokens ``0..L-1``.
        """
        B, L, H, D = q.shape
        if indices is None:
            indices = torch.arange(L, device=q.device).unsqueeze(-1)
        # Fold batch into the token axis and put heads first, so the embedder
        # sees [H, B*L, D] tensors with one [B*L, C] position table that
        # broadcasts over the heads.
        flat_indices = indices.expand(B, L, -1).reshape(B * L, -1)
        q = q.reshape(B * L, H, D).transpose(0, 1)
        k = k.reshape(B * L, H, D).transpose(0, 1)
        q, k = self.rope(q, k, flat_indices)
        q = q.transpose(0, 1).reshape(B, L, H, D)
        k = k.transpose(0, 1).reshape(B, L, H, D)
        return q, k

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run attention over ``x``.

        Args:
            x: ``[B, L, channels]`` input tokens.
            context: ``[B, Lkv, ctx_channels]`` tokens attended to in
                cross-attention; required when ``type == "cross"`` and
                ignored for self-attention.
            indices: Optional token positions for the rotary embedding,
                ``[L, C]`` or ``[B, L, C]``; only used when ``use_rope`` is
                set on a self-attention layer.

        Returns:
            ``[B, L, channels]`` output tokens.

        Raises:
            ValueError: If cross-attention is called without ``context``.
            NotImplementedError: If ``attn_mode`` is ``"windowed"``.
        """
        B, L, C = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x)
            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
            if self.use_rope:
                q, k, v = qkv.unbind(dim=2)
                q, k = self._apply_rope(q, k, indices)
                qkv = torch.stack([q, k, v], dim=2)
            if self.attn_mode == "full":
                if self.qk_rms_norm:
                    q, k, v = qkv.unbind(dim=2)
                    q = self.q_rms_norm(q)
                    k = self.k_rms_norm(k)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            if context is None:
                raise ValueError("Cross-attention requires a `context` tensor")
            Lkv = context.shape[1]
            q = self.to_q(x)
            kv = self.to_kv(context)
            q = q.reshape(B, L, self.num_heads, -1)
            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                h = scaled_dot_product_attention(q, kv)
        # Merge the heads back into a single channel axis before projecting.
        h = h.reshape(B, L, -1)
        h = self.to_out(h)
        return h
|
|
|