import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RobustVelocityAdapter(nn.Module):
    """
    Fixed version: manual multi-head cross-attention emits [B, heads, Q, K]
    scores so that _add_rel_pos_bias can unpack them correctly.
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        hidden_dim: int = 1024,
        out_tokens: int = 64,
        self_attn_layers: int = 2,
        cross_heads: int = 8,
        max_rel_pos: int = 128,
    ):
        super().__init__()
        self.out_tokens = out_tokens
        self.cross_heads = cross_heads
        self.head_dim = t5_dim // cross_heads
        self.max_rel_pos = max_rel_pos

        # Self-attention encoder layers over the T5 token sequence.
        self.self_attn = nn.ModuleList()
        self.self_norm = nn.ModuleList()
        for _ in range(self_attn_layers):
            self.self_attn.append(
                nn.MultiheadAttention(t5_dim, cross_heads, batch_first=True)
            )
            self.self_norm.append(nn.LayerNorm(t5_dim))

        # Residual MLP blocks used to refine the encoded sequence.
        def resblock():
            return nn.Sequential(
                nn.LayerNorm(t5_dim),
                nn.Linear(t5_dim, t5_dim),
                nn.GELU(),
                nn.Linear(t5_dim, t5_dim),
            )

        self.res1 = resblock()
        self.res2 = resblock()

        # Learned positional queries for the fixed-length output sequence.
        self.query_pos = nn.Parameter(torch.randn(out_tokens, t5_dim))

        # Output heads: anchor, delta mean, log-variance, and gating projections.
        self.anchor_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
        )
        self.var_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim), nn.Sigmoid()
        )

        # Relative-position bias table: one entry per (relative offset, head).
        self.rel_bias = nn.Parameter(torch.zeros(2 * max_rel_pos - 1, cross_heads))

        self.cross_norm = nn.LayerNorm(t5_dim)

    def _add_rel_pos_bias(self, attn_scores: torch.Tensor) -> torch.Tensor:
        """
        attn_scores: [B, heads, Q, K]
        returns: attn_scores + bias, where bias is [B, heads, Q, K]
        """
        B, H, Q, K = attn_scores.shape
        device = attn_scores.device

        # Signed relative offsets between query and key positions: [Q, K].
        idx_q = torch.arange(Q, device=device)
        idx_k = torch.arange(K, device=device)
        rel = idx_q.unsqueeze(1) - idx_k.unsqueeze(0)

        # Clamp offsets to the supported range and shift into [0, 2*max_rel_pos - 2].
        max_rel = self.max_rel_pos
        rel = rel.clamp(-max_rel + 1, max_rel - 1) + (max_rel - 1)

        # Look up per-head biases: [Q, K, H] -> [H, Q, K] -> [B, H, Q, K].
        bias = self.rel_bias[rel]
        bias = bias.permute(2, 0, 1)
        bias = bias.unsqueeze(0).expand(B, -1, -1, -1)
        return attn_scores + bias

    def forward(self, t5_seq: torch.Tensor):
        """
        t5_seq: [B, L, t5_dim]
        returns:
            anchor: [B, out_tokens, clip_dim]
            delta:  [B, out_tokens, clip_dim]
            sigma:  [B, out_tokens, clip_dim]
        """
        x = t5_seq
        B, L, D = x.shape

        # Self-attention encoder with residual connections and post-norm.
        for attn, norm in zip(self.self_attn, self.self_norm):
            res, _ = attn(x, x, x)
            x = norm(x + res)

        # Residual MLP refinement.
        x = x + self.res1(x)
        x = x + self.res2(x)

        # Manual multi-head cross-attention: learned queries attend to the sequence.
        queries = self.query_pos.unsqueeze(0).expand(B, -1, -1)  # [B, out_tokens, D]
        q = queries.view(B, self.out_tokens, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        k = x.view(B, L, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        v = k

        # Scaled dot-product scores with relative-position bias: [B, heads, Q, K].
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = self._add_rel_pos_bias(scores)
        probs = F.softmax(scores, dim=-1)

        # Aggregate values and merge heads back to [B, out_tokens, D].
        ctx = probs @ v
        ctx = ctx.permute(0, 2, 1, 3).reshape(B, self.out_tokens, D)
        ctx = self.cross_norm(ctx)

        # Project the attended context through the output heads.
        anchor = self.anchor_proj(ctx)
        delta_mean = self.delta_proj(ctx)
        delta_logvar = self.var_proj(ctx)
        gate = self.gate_proj(ctx)

        # Treat var_proj output as a log-variance; gate the predicted delta.
        sigma = torch.exp(0.5 * delta_logvar)
        delta = delta_mean * gate

        return anchor, delta, sigma
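

# Minimal usage sketch (not part of the original module): instantiate the adapter
# with its default dimensions and run a forward pass on dummy data to check output
# shapes. The batch size (2) and sequence length (77) below are illustrative
# assumptions, not values taken from the source.
if __name__ == "__main__":
    adapter = RobustVelocityAdapter()
    t5_seq = torch.randn(2, 77, 512)  # [B, L, t5_dim]
    anchor, delta, sigma = adapter(t5_seq)
    # Each output should be [B, out_tokens, clip_dim] = [2, 64, 768].
    print(anchor.shape, delta.shape, sigma.shape)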