|
|
|
|
|
|
|
|
|
from typing import Callable, Optional |
|
import torch |
|
from torch import nn |
|
from inspect import isfunction |
|
from einops import rearrange |
|
|
|
class AdaptiveLayerNorm1D(torch.nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning tensor.

    The normalized input is modulated as ``norm(x) * (1 + alpha) + beta``, where
    ``alpha`` and ``beta`` are the two halves of ``linear(t)``. The projection is
    zero-initialised, so a freshly constructed module behaves like a plain
    ``LayerNorm``.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/t_cond_mlp.py#L7
    """

    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        if data_dim <= 0:
            raise ValueError(f"data_dim must be positive, but got {data_dim}")
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")

        self.norm = torch.nn.LayerNorm(data_dim)
        # A single projection yields both modulation terms, laid out as [alpha | beta].
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        # Zero init => alpha == beta == 0 at the start, i.e. identity modulation.
        torch.nn.init.zeros_(self.linear.weight)
        torch.nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dim and modulate it using ``t``.

        Args:
            x: tensor whose last dim is ``data_dim``; dim 0 is treated as batch.
            t: conditioning tensor whose last dim is ``norm_cond_dim``. When ``x``
                has more than 2 dims, ``t`` must be 2-D ``(batch, norm_cond_dim)``
                because the broadcast reshape uses ``shape[0]`` and ``shape[1]``.

        Returns:
            The modulated tensor (broadcast shape of ``x`` and the modulation).
        """
        normed = self.norm(x)
        scale, shift = self.linear(t).chunk(2, dim=-1)

        extra_dims = x.dim() - 2
        if extra_dims > 0:
            # Insert singleton axes between batch and feature dims so the
            # per-sample modulation broadcasts over every intermediate axis.
            broadcast_shape = (scale.shape[0], *([1] * extra_dims), scale.shape[1])
            scale = scale.view(broadcast_shape)
            shift = shift.view(broadcast_shape)

        return normed * (1 + scale) + shift
|
|
|
|
|
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    """Build the normalization module selected by ``norm``.

    Args:
        norm: ``"batch"`` -> ``BatchNorm1d``, ``"layer"`` -> ``LayerNorm``,
            ``"ada"`` -> ``AdaptiveLayerNorm1D``, ``None`` -> ``Identity``.
        dim: feature dimension to normalize.
        norm_cond_dim: conditioning dimension; only used by ``"ada"``, for
            which it must be positive.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: if ``norm`` is unknown, or ``norm == "ada"`` with a
            non-positive ``norm_cond_dim``.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/t_cond_mlp.py#L48
    """
    if norm == "batch":
        return torch.nn.BatchNorm1d(dim)
    if norm == "layer":
        return torch.nn.LayerNorm(dim)
    if norm == "ada":
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently change the failure mode.
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, got {norm_cond_dim}")
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    if norm is None:
        return torch.nn.Identity()
    raise ValueError(f"Unknown norm: {norm}")
|
|
|
|
|
def exists(val):
    """Return ``True`` unless ``val`` is ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L17
    """
    if val is None:
        return False
    return True
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return the fallback ``d``.

    If ``d`` is a plain Python function (as detected by ``inspect.isfunction``,
    e.g. a ``def`` or a ``lambda``), it is called with no arguments to produce
    the fallback lazily. Any other object is returned as-is.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L21
    """
    if not exists(val):
        fallback = d
        if isfunction(fallback):
            fallback = fallback()
        return fallback
    return val
|
|
|
|
|
class PreNorm(nn.Module):
    """Normalize the input, then call the wrapped ``fn`` on the result.

    Extra positional arguments given to :meth:`forward` are passed only to an
    ``AdaptiveLayerNorm1D`` norm, as its conditioning input; other norms
    ignore them. Keyword arguments are passed only to ``fn``.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L27
    """

    def __init__(self, dim: int, fn: Callable, norm: Optional[str] = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Only the adaptive norm takes a conditioning input; for every other
        # norm the extra positional args are intentionally dropped.
        takes_condition = isinstance(self.norm, AdaptiveLayerNorm1D)
        norm_inputs = (x, *args) if takes_condition else (x,)
        return self.fn(self.norm(*norm_inputs), **kwargs)
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps features of size ``dim`` through a hidden layer of size
    ``hidden_dim`` and back to ``dim``.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L40
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, dim)
        output_drop = nn.Dropout(dropout)
        # A single Sequential named ``net`` keeps checkpoint keys as
        # ``net.0.*`` / ``net.3.*``.
        self.net = nn.Sequential(expand, activation, hidden_drop, project, output_drop)

    def forward(self, x):
        out = self.net(x)
        return out
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with an optional multiplicative token mask.

    ``mask`` is indexed as ``mask[:, None, :, None]`` and
    ``mask[:, None, None, :]`` against ``(b, h, n, n)`` scores, so it is
    expected to be ``(batch, seq)``. Presumably it is 1 for valid and 0 for
    padded tokens; confirm against callers. NOTE(review): ``1 - mask`` is
    computed, which PyTorch rejects for bool tensors, so pass a float/int mask.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L55
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection can only be skipped when the concatenated heads
        # already have width ``dim`` (single head of size ``dim``).
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, mask=None):
        """Attend over the tokens of ``x``.

        Args:
            x: ``(batch, seq, dim)`` tokens.
            mask: optional ``(batch, seq)`` multiplicative token mask.

        Returns:
            ``(batch, seq, dim)`` tensor.
        """
        per_head = [
            rearrange(chunk, "b n (h d) -> b h n d", h=self.heads)
            for chunk in self.to_qkv(x).chunk(3, dim=-1)
        ]
        if mask is not None:
            # Zero the q/k/v vectors of masked tokens.
            token_mask = mask[:, None, :, None]
            per_head = [t * token_mask for t in per_head]
        q, k, v = per_head

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if mask is not None:
            # Push masked key columns to a huge negative logit (10e10 == 1e11)
            # before the softmax.
            dots = dots - (1 - mask)[:, None, None, :] * 10e10

        attn = self.attend(dots)
        if mask is not None:
            # Force exactly-zero weight on masked keys after the softmax.
            attn = attn * mask[:, None, None, :]
        attn = self.dropout(attn)

        merged = rearrange(torch.matmul(attn, v), "b h n d -> b n (h d)")
        return self.to_out(merged)
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention from the tokens of ``x`` to ``context``.

    ``mask`` (optional) is a per-query multiplicative mask, indexed as
    ``mask[:, None, :, None]`` against ``(b, h, n_q, ...)``. It zeroes masked
    queries and their outputs. Context keys are never masked.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L89
    """

    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is only skippable when heads already concat to ``dim``.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # Without an explicit context_dim, the context is assumed to share x's width.
        context_dim = default(context_dim, dim)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` (queries) to ``context`` (keys/values).

        Args:
            x: ``(batch, n_q, dim)`` query tokens.
            context: ``(batch, n_kv, context_dim)`` tokens; defaults to ``x``.
            mask: optional ``(batch, n_q)`` query mask (float, int or bool).

        Returns:
            ``(batch, n_q, dim)`` tensor.
        """
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q = self.to_q(x)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])

        if mask is not None:
            # Zero padded queries; their score rows become constant (uniform attention)
            # and their outputs are zeroed again below.
            q = q * mask[:, None, :, None]

        # The query mask is deliberately NOT subtracted from ``dots``. Broadcast
        # over the query axis it would only shift each softmax row by a constant,
        # which softmax ignores. The previous ``(1 - mask) * 1e6`` term was
        # therefore a no-op that also crashed on bool masks.
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)

        if mask is not None:
            out = out * mask[:, None, :, None]
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
|
|
class TransformerCrossAttn(nn.Module):
    """Stack of pre-norm (self-attention, cross-attention, feed-forward) layers.

    Each sub-layer is wrapped in ``PreNorm`` and added back as a residual.
    Extra positional ``*args`` passed to :meth:`forward` reach every
    ``PreNorm`` (where only an ``"ada"`` norm consumes them as conditioning).

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L160
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            ca = CrossAttention(
                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            ff = FeedForward(dim, mlp_dim, dropout=dropout)
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
                    ]
                )
            )

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None, mask=None):
        """Run all layers over ``x``.

        Args:
            x: ``(batch, n, dim)`` tokens.
            *args: conditioning forwarded to each ``PreNorm``.
            context: context shared by every layer's cross-attention.
            context_list: per-layer contexts; overrides ``context`` when given.
            mask: optional ``(batch, n)`` multiplicative token mask.

        Raises:
            ValueError: if ``context_list`` does not have one entry per layer.
        """
        if context_list is None:
            context_list = [context] * len(self.layers)

        if len(context_list) != len(self.layers):
            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")

        for (self_attn, cross_attn, ff), layer_context in zip(self.layers, context_list):
            if mask is not None:
                # Re-zero masked tokens before each layer. A mask/x shape mismatch
                # now raises normally instead of opening an interactive debugger.
                x = x * mask[:, :, None]
            x = self_attn(x, *args, mask=mask) + x
            x = cross_attn(x, *args, context=layer_context, mask=mask) + x
            x = ff(x, *args) + x

        if mask is not None:
            x = x * mask[:, :, None]

        return x
|
|
|
class DropTokenDropout(nn.Module):
    """Token dropout that removes whole token positions from the sequence.

    In training mode each position along dim 1 is dropped with probability
    ``p``. One mask is drawn (shaped like ``x[0, :, 0]``) and shared by every
    batch element, so the output may be shorter along dim 1. In eval mode, or
    when ``p`` is 0, the input is returned unchanged.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L204
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        if not (self.training and self.p > 0):
            return x
        # One Bernoulli draw per token position, shared across the batch.
        keep = ~torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
        if keep.all():
            return x
        return x[:, keep, :]
|
|
|
|
|
class ZeroTokenDropout(nn.Module):
    """Token dropout that zeroes whole tokens instead of removing them.

    In training mode every ``(batch, token)`` position is zeroed independently
    with probability ``p`` (a mask shaped like ``x[:, :, 0]``), so the sequence
    length is preserved. In eval mode, or when ``p`` is 0, the input is returned
    unchanged.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L223
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Out-of-place: in-place ``x[zero_mask, :] = 0`` mutated the caller's
            # tensor and raised for leaf tensors that require grad.
            x = x.masked_fill(zero_mask[:, :, None], 0)
        return x
|
|
|
|
|
class TransformerDecoder(nn.Module):
    """Embed input tokens, add a positional embedding, and run ``TransformerCrossAttn``.

    Code modified from https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L301
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        """
        Raises:
            ValueError: if ``skip_token_embedding`` is set while
                ``token_dim != dim``, or ``emb_dropout_type`` is not one of
                ``'drop'``, ``'zero'``, ``'normal'``.
        """
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )

        # NOTE(review): forward() only reads pos_embedding[:, 0]; the remaining
        # num_tokens - 1 rows are allocated but unused. Confirm this is intended.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Fail fast: previously self.dropout was left unset and forward()
            # crashed later with an AttributeError.
            raise ValueError(
                f"Unknown emb_dropout_type: {emb_dropout_type!r} "
                "(expected 'drop', 'zero' or 'normal')"
            )

        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None, mask=None):
        """Decode ``inp`` tokens.

        Args:
            inp: ``(batch, n, token_dim)`` input tokens.
            *args: conditioning forwarded to the transformer's ``PreNorm`` layers.
            context, context_list, mask: forwarded to ``TransformerCrossAttn``.

        Returns:
            ``(batch, n', dim)`` tokens. ``n' < n`` is possible when
            ``emb_dropout_type == 'drop'`` in training mode.

        Raises:
            ValueError: if the embedded tokens are not 3-D.
        """
        x = self.to_token_embedding(inp)
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D (batch, n, token_dim) input, got shape {tuple(inp.shape)}")

        # NOTE(review): with emb_dropout_type='drop' in training mode,
        # DropTokenDropout can remove tokens, after which ``mask`` no longer
        # matches ``x`` along dim 1.
        x = self.dropout(x)

        # Every token gets the same (first) learned positional embedding.
        # Out-of-place add: with skip_token_embedding, ``x`` may still be the
        # caller's ``inp``, which ``+=`` would have mutated.
        x = x + self.pos_embedding[:, 0][:, None, :]
        x = self.transformer(x, *args, context=context, context_list=context_list, mask=mask)
        return x
|
|