import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RobustVelocityAdapter(nn.Module):
    """
    Fixed version: manual multi-head cross-attention emits [B, heads, Q, K]
    scores so that _add_rel_pos_bias can unpack them correctly.
    """
    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        hidden_dim: int = 1024,
        out_tokens: int = 64,   # now aligned with your T5 finetune
        self_attn_layers: int = 2,
        cross_heads: int = 8,
        max_rel_pos: int = 128,
    ):
        super().__init__()
        self.out_tokens = out_tokens
        self.cross_heads = cross_heads
        self.head_dim = t5_dim // cross_heads
        self.max_rel_pos = max_rel_pos

        # 1) Self-attention stack
        self.self_attn = nn.ModuleList()
        self.self_norm = nn.ModuleList()
        for _ in range(self_attn_layers):
            self.self_attn.append(nn.MultiheadAttention(t5_dim, cross_heads, batch_first=True))
            self.self_norm.append(nn.LayerNorm(t5_dim))

        # 2) Residual blocks
        def resblock():
            return nn.Sequential(
                nn.LayerNorm(t5_dim),
                nn.Linear(t5_dim, t5_dim),
                nn.GELU(),
                nn.Linear(t5_dim, t5_dim),
            )
        self.res1 = resblock()
        self.res2 = resblock()

        # 3) Learned queries for cross-attn
        self.query_pos = nn.Parameter(torch.randn(out_tokens, t5_dim))

        # 4) Projection heads
        self.anchor_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, clip_dim)
        )
        self.var_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, clip_dim), nn.Sigmoid()
        )

        # 5) Relative-position bias table
        self.rel_bias = nn.Parameter(torch.zeros(2 * max_rel_pos - 1, cross_heads))

        # 6) Norm after cross-attn
        self.cross_norm = nn.LayerNorm(t5_dim)

    def _add_rel_pos_bias(self, attn_scores: torch.Tensor) -> torch.Tensor:
        """
        attn_scores: [B, heads, Q, K]
        returns:     attn_scores + bias, where bias is [B, heads, Q, K]
        """
        B, H, Q, K = attn_scores.shape
        device = attn_scores.device

        # 1) Query & key position indices
        idx_q = torch.arange(Q, device=device)  # [Q]
        idx_k = torch.arange(K, device=device)  # [K]

        # 2) Compute relative distances for every (q, k) pair:
        #    rel[i, j] = idx_q[i] - idx_k[j]
        rel = idx_q.unsqueeze(1) - idx_k.unsqueeze(0)  # [Q, K]

        # 3) Clamp & shift into bias-table range [0, 2*max_rel - 2]
        max_rel = self.max_rel_pos
        rel = rel.clamp(-max_rel + 1, max_rel - 1) + (max_rel - 1)

        # 4) Look up per-head biases
        #    self.rel_bias has shape [2*max_rel - 1, H]
        bias = self.rel_bias[rel]        # [Q, K, H]
        bias = bias.permute(2, 0, 1)     # [H, Q, K]

        # 5) Broadcast to [B, H, Q, K] and add
        bias = bias.unsqueeze(0).expand(B, -1, -1, -1)
        return attn_scores + bias

    def forward(self, t5_seq: torch.Tensor):
        """
        t5_seq: [B, L, t5_dim]
        returns:
            anchor: [B, out_tokens, clip_dim]
            delta:  [B, out_tokens, clip_dim]
            sigma:  [B, out_tokens, clip_dim]
        """
        x = t5_seq
        B, L, D = x.shape

        # 1) Self-attention + residual
        for attn, norm in zip(self.self_attn, self.self_norm):
            res, _ = attn(x, x, x)
            x = norm(x + res)

        # 2) Residual blocks
        x = x + self.res1(x)
        x = x + self.res2(x)

        # 3) Prepare queries & split heads
        queries = self.query_pos.unsqueeze(0).expand(B, -1, -1)  # [B, Q, D]
        # reshape into heads
        q = queries.view(B, self.out_tokens, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        k = x.view(B, L, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        v = k

        # 4) Scaled dot-product to get [B, heads, Q, K]
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = self._add_rel_pos_bias(scores)
        probs = F.softmax(scores, dim=-1)  # [B, H, Q, K]

        # 5) Attend & merge heads -> [B, Q, D]
        ctx = probs @ v                                        # [B, H, Q, head_dim]
        ctx = ctx.permute(0, 2, 1, 3).reshape(B, self.out_tokens, D)
        ctx = self.cross_norm(ctx)

        # 6) Project to anchor, delta_mean, delta_logvar, gate
        anchor = self.anchor_proj(ctx)
        delta_mean = self.delta_proj(ctx)
        delta_logvar = self.var_proj(ctx)
        gate = self.gate_proj(ctx)

        # 7) Compute sigma & gated delta
        sigma = torch.exp(0.5 * delta_logvar)
        delta = delta_mean * gate

        return anchor, delta, sigma
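
# Minimal smoke-test sketch (an addition, not part of the original module): it
# instantiates the adapter with its default dimensions, pushes a random
# T5-style sequence through it, and checks that all three outputs land in
# CLIP space with out_tokens query slots. The batch size and sequence length
# below are arbitrary placeholders.
if __name__ == "__main__":
    adapter = RobustVelocityAdapter()      # defaults: t5_dim=512, clip_dim=768, out_tokens=64
    dummy_t5 = torch.randn(2, 77, 512)     # [B=2, L=77, t5_dim=512], hypothetical batch
    anchor, delta, sigma = adapter(dummy_t5)
    # The four projection heads map the 64 attended query tokens into CLIP space.
    print(anchor.shape, delta.shape, sigma.shape)  # each: torch.Size([2, 64, 768])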