# robust_velocity_adapter.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class RobustVelocityAdapter(nn.Module):
"""
Fixed version: manual multi-head cross-attention emits [B, heads, Q, K] scores
so that _add_rel_pos_bias can unpack them correctly.
"""
def __init__(
self,
t5_dim: int = 512,
clip_dim: int = 768,
hidden_dim: int = 1024,
out_tokens: int = 64, # now aligned with your T5 finetune
self_attn_layers: int = 2,
cross_heads: int = 8,
max_rel_pos: int = 128,
):
super().__init__()
self.out_tokens = out_tokens
self.cross_heads = cross_heads
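        # Sanity check (added): the head split below assumes cross_heads divides t5_dim evenly
        assert t5_dim % cross_heads == 0, "t5_dim must be divisible by cross_heads"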
self.head_dim = t5_dim // cross_heads
self.max_rel_pos = max_rel_pos
# 1) Self-attention stack
self.self_attn = nn.ModuleList()
self.self_norm = nn.ModuleList()
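        # nn.MultiheadAttention with batch_first=True expects and returns [B, L, t5_dim]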
for _ in range(self_attn_layers):
self.self_attn.append(nn.MultiheadAttention(t5_dim, cross_heads, batch_first=True))
self.self_norm.append(nn.LayerNorm(t5_dim))
# 2) Residual blocks
def resblock():
return nn.Sequential(
nn.LayerNorm(t5_dim),
nn.Linear(t5_dim, t5_dim),
nn.GELU(),
nn.Linear(t5_dim, t5_dim),
)
self.res1 = resblock()
self.res2 = resblock()
# 3) Learned queries for cross-attn
self.query_pos = nn.Parameter(torch.randn(out_tokens, t5_dim))
# 4) Projection heads
self.anchor_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
)
self.delta_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
)
self.var_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
)
self.gate_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim), nn.Sigmoid()
)
# 5) Relative-position bias table
self.rel_bias = nn.Parameter(torch.zeros(2*max_rel_pos-1, cross_heads))
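        # Rows 0 .. 2*max_rel_pos-2 of this table cover relative distances
        # -(max_rel_pos-1) .. +(max_rel_pos-1); see _add_rel_pos_bias below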
# 6) Norm after cross-attn
self.cross_norm = nn.LayerNorm(t5_dim)
def _add_rel_pos_bias(self, attn_scores: torch.Tensor) -> torch.Tensor:
"""
attn_scores: [B, heads, Q, K]
returns: attn_scores + bias where bias is [B, heads, Q, K]
"""
B, H, Q, K = attn_scores.shape
device = attn_scores.device
# 1) Query & key position indices
idx_q = torch.arange(Q, device=device) # [Q]
idx_k = torch.arange(K, device=device) # [K]
# 2) Compute relative distances for every (q, k) pair
# rel[i,j] = idx_q[i] - idx_k[j]
rel = idx_q.unsqueeze(1) - idx_k.unsqueeze(0) # [Q, K]
# 3) Clamp & shift into bias table range [0, 2*max_rel-2]
max_rel = self.max_rel_pos
rel = rel.clamp(-max_rel+1, max_rel-1) + (max_rel - 1)
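        #    e.g. with max_rel_pos = 128: rel = +3 maps to row 3 + 127 = 130,
        #    and rel = -130 is clamped to -127, which maps to row 0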
# 4) Lookup per-head biases
# self.rel_bias has shape [2*max_rel-1, H]
bias = self.rel_bias[rel] # [Q, K, H]
bias = bias.permute(2, 0, 1) # [H, Q, K]
# 5) Broadcast to [B, H, Q, K] and add
bias = bias.unsqueeze(0).expand(B, -1, -1, -1)
return attn_scores + bias
def forward(self, t5_seq: torch.Tensor):
"""
t5_seq: [B, L, t5_dim]
returns:
anchor: [B, out_tokens, clip_dim]
delta: [B, out_tokens, clip_dim]
sigma: [B, out_tokens, clip_dim]
"""
x = t5_seq
B, L, D = x.shape
# 1) Self-attention + residual
for attn, norm in zip(self.self_attn, self.self_norm):
res, _ = attn(x, x, x)
x = norm(x + res)
# 2) Residual blocks
x = x + self.res1(x)
x = x + self.res2(x)
# 3) Prepare queries & split heads
queries = self.query_pos.unsqueeze(0).expand(B, -1, -1) # [B, Q, D]
# reshape into heads
        q = queries.view(B, self.out_tokens, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        k = x.view(B, L, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
v = k
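        # Note: keys and values share one tensor here; this design applies no separate
        # K/V projections and attends directly over the processed T5 states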
# 4) Scaled dot-product to get [B, heads, Q, K]
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = self._add_rel_pos_bias(scores)
        probs = F.softmax(scores, dim=-1)  # [B, H, Q, K]
        # 5) Attend & merge heads → [B, Q, D]
        ctx = probs @ v  # [B, H, Q, head_dim]
        ctx = ctx.permute(0, 2, 1, 3).reshape(B, self.out_tokens, D)
ctx = self.cross_norm(ctx)
# 6) Project to anchor, delta_mean, delta_logvar, gate
anchor = self.anchor_proj(ctx)
delta_mean = self.delta_proj(ctx)
delta_logvar = self.var_proj(ctx)
gate = self.gate_proj(ctx)
# 7) Compute sigma & gated delta
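        #    exp(0.5 * logvar) turns the predicted log-variance into a standard deviation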
sigma = torch.exp(0.5 * delta_logvar)
delta = delta_mean * gate
return anchor, delta, sigma
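

# Minimal smoke test (added sketch, not part of the original file). Shapes assume the
# default constructor arguments; the sequence length 77 is an arbitrary example.
if __name__ == "__main__":
    adapter = RobustVelocityAdapter()
    t5_seq = torch.randn(2, 77, 512)  # [B=2, L=77, t5_dim=512]
    anchor, delta, sigma = adapter(t5_seq)
    # Each output should be [B, out_tokens, clip_dim] = [2, 64, 768]
    print(anchor.shape, delta.shape, sigma.shape)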