import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RobustVelocityAdapter(nn.Module):
"""
Fixed version: manual multi-head cross-attention emits [B, heads, Q, K] scores
so that _add_rel_pos_bias can unpack them correctly.
"""
def __init__(
self,
t5_dim: int = 512,
clip_dim: int = 768,
hidden_dim: int = 1024,
out_tokens: int = 64, # now aligned with your T5 finetune
self_attn_layers: int = 2,
cross_heads: int = 8,
max_rel_pos: int = 128,
):
super().__init__()
self.out_tokens = out_tokens
self.cross_heads = cross_heads
self.head_dim = t5_dim // cross_heads
self.max_rel_pos = max_rel_pos
# 1) Self-attention stack
self.self_attn = nn.ModuleList()
self.self_norm = nn.ModuleList()
for _ in range(self_attn_layers):
self.self_attn.append(nn.MultiheadAttention(t5_dim, cross_heads, batch_first=True))
self.self_norm.append(nn.LayerNorm(t5_dim))
# 2) Residual blocks
def resblock():
return nn.Sequential(
nn.LayerNorm(t5_dim),
nn.Linear(t5_dim, t5_dim),
nn.GELU(),
nn.Linear(t5_dim, t5_dim),
)
self.res1 = resblock()
self.res2 = resblock()
# 3) Learned queries for cross-attn
self.query_pos = nn.Parameter(torch.randn(out_tokens, t5_dim))
# 4) Projection heads
self.anchor_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
)
self.delta_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
)
self.var_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
)
self.gate_proj = nn.Sequential(
nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim), nn.Sigmoid()
)
# 5) Relative-position bias table
self.rel_bias = nn.Parameter(torch.zeros(2*max_rel_pos-1, cross_heads))
# 6) Norm after cross-attn
self.cross_norm = nn.LayerNorm(t5_dim)

    def _add_rel_pos_bias(self, attn_scores: torch.Tensor) -> torch.Tensor:
        """
        attn_scores: [B, heads, Q, K]
        returns: attn_scores + bias, where bias is [B, heads, Q, K]
        """
        B, H, Q, K = attn_scores.shape
        device = attn_scores.device

        # 1) Query & key position indices
        idx_q = torch.arange(Q, device=device)  # [Q]
        idx_k = torch.arange(K, device=device)  # [K]

        # 2) Compute relative distances for every (q, k) pair:
        #    rel[i, j] = idx_q[i] - idx_k[j]
        rel = idx_q.unsqueeze(1) - idx_k.unsqueeze(0)  # [Q, K]

        # 3) Clamp & shift into bias-table range [0, 2*max_rel-2]
        max_rel = self.max_rel_pos
        rel = rel.clamp(-max_rel + 1, max_rel - 1) + (max_rel - 1)

        # 4) Look up per-head biases
        #    self.rel_bias has shape [2*max_rel-1, H]
        bias = self.rel_bias[rel]     # [Q, K, H]
        bias = bias.permute(2, 0, 1)  # [H, Q, K]

        # 5) Broadcast to [B, H, Q, K] and add
        bias = bias.unsqueeze(0).expand(B, -1, -1, -1)
        return attn_scores + bias
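
    # Worked example of the index arithmetic above (illustrative only; the
    # numbers are not from the original file). With max_rel_pos = 128, each
    # (query position i, key position j) pair gives rel = i - j, clamped to
    # [-127, 127] and shifted by +127 into [0, 254], which indexes the
    # 2*128 - 1 = 255 rows of self.rel_bias. For Q = 3 queries, K = 4 keys:
    #   rel        = [[ 0, -1, -2, -3],
    #                 [ 1,  0, -1, -2],
    #                 [ 2,  1,  0, -1]]
    #   rel + 127  = [[127, 126, 125, 124],
    #                 [128, 127, 126, 125],
    #                 [129, 128, 127, 126]]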

    def forward(self, t5_seq: torch.Tensor):
        """
        t5_seq: [B, L, t5_dim]
        returns:
            anchor: [B, out_tokens, clip_dim]
            delta:  [B, out_tokens, clip_dim]
            sigma:  [B, out_tokens, clip_dim]
        """
        x = t5_seq
        B, L, D = x.shape

        # 1) Self-attention + residual
        for attn, norm in zip(self.self_attn, self.self_norm):
            res, _ = attn(x, x, x)
            x = norm(x + res)

        # 2) Residual blocks
        x = x + self.res1(x)
        x = x + self.res2(x)

        # 3) Prepare queries & split heads
        queries = self.query_pos.unsqueeze(0).expand(B, -1, -1)  # [B, Q, D]
        # reshape into heads
        q = queries.view(B, self.out_tokens, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        k = x.view(B, L, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        v = k

        # 4) Scaled dot-product to get [B, heads, Q, K]
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = self._add_rel_pos_bias(scores)
        probs = F.softmax(scores, dim=-1)  # [B, H, Q, K]

        # 5) Attend & merge heads → [B, Q, D]
        ctx = probs @ v  # [B, H, Q, head_dim]
        ctx = ctx.permute(0, 2, 1, 3).reshape(B, self.out_tokens, D)
        ctx = self.cross_norm(ctx)

        # 6) Project to anchor, delta_mean, delta_logvar, gate
        anchor = self.anchor_proj(ctx)
        delta_mean = self.delta_proj(ctx)
        delta_logvar = self.var_proj(ctx)
        gate = self.gate_proj(ctx)

        # 7) Compute sigma & gated delta
        sigma = torch.exp(0.5 * delta_logvar)
        delta = delta_mean * gate
        return anchor, delta, sigma
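

# Minimal usage sketch (not part of the original module; the batch size,
# sequence length, and random inputs below are arbitrary assumptions): run a
# dummy T5 encoder sequence through the adapter and check the output shapes.
if __name__ == "__main__":
    adapter = RobustVelocityAdapter()  # defaults: t5_dim=512, clip_dim=768, out_tokens=64
    t5_seq = torch.randn(2, 77, 512)   # [B=2, L=77, t5_dim=512] stand-in for T5 encoder output
    anchor, delta, sigma = adapter(t5_seq)
    print(anchor.shape, delta.shape, sigma.shape)
    # expected: three tensors of shape [2, 64, 768]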