# lorocksUMD's picture
# Upload 32 files
# 099ac14 verified
from functools import partial
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from DenseAV.denseav.featurizers.DINO import Block
class ChannelNorm(torch.nn.Module):
    """LayerNorm applied over the channel axis of a feature map and a vector.

    A single ``torch.nn.LayerNorm(dim, eps=1e-4)`` is shared by the spatial
    features and the optional ``cls`` tensor.

    Args:
        dim: Size of the normalized (channel) dimension.
        *args, **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-4)

    def forward_spatial(self, x):
        """Normalize a 4-D tensor over dim 1 and return it in the same layout."""
        # LayerNorm works on the trailing axis, so move dim 1 last,
        # normalize, then restore the original axis order.
        channels_last = x.permute(0, 2, 3, 1)
        normalized = self.norm(channels_last)
        return normalized.permute(0, 3, 1, 2)

    def forward_cls(self, cls):
        """Normalize ``cls`` over its last axis; ``None`` passes through."""
        if cls is None:
            return None
        return self.norm(cls)

    def forward(self, x, cls):
        """Return ``(normalized spatial features, normalized cls or None)``."""
        spatial_out = self.forward_spatial(x)
        cls_out = self.forward_cls(cls)
        return spatial_out, cls_out
def id_conv(dim, strength=.9):
    """Build a 1x1 ``Conv2d`` initialized as a blend of identity and default init.

    The weight becomes ``strength * I + (1 - strength) * W0`` and the bias
    becomes ``(1 - strength) * b0``, where ``W0``/``b0`` are the values
    PyTorch assigns at construction. ``strength=1`` therefore yields an exact
    identity mapping with zero bias, while ``strength=0`` leaves the default
    initialization untouched.

    Args:
        dim: Number of input and output channels.
        strength: Blend factor between the identity kernel and default init.

    Returns:
        The initialized ``torch.nn.Conv2d`` module.
    """
    conv = torch.nn.Conv2d(dim, dim, 1, padding="same")
    # (dim, dim) identity expanded to the (dim, dim, 1, 1) kernel layout.
    identity_kernel = torch.eye(dim, device=conv.weight.device).unsqueeze(-1).unsqueeze(-1)
    keep = 1 - strength
    with torch.no_grad():
        # The right-hand side is fully evaluated before it overwrites the weight.
        conv.weight.copy_(identity_kernel * strength + conv.weight * keep)
        conv.bias.mul_(keep)
    return conv
class LinearAligner(torch.nn.Module):
    """Project spatial features with a 1x1 conv and ``cls`` with a linear layer.

    Args:
        in_dim: Channel count of the incoming features.
        out_dim: Channel count of the produced features.
        use_norm: When true, spatial features pass through ``ChannelNorm``
            before the projection; otherwise ``Identity2`` is used.
    """

    def __init__(self, in_dim, out_dim, use_norm=True):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.norm = ChannelNorm(in_dim) if use_norm else Identity2()
        if in_dim != out_dim:
            self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1)
        else:
            # strength 0 keeps PyTorch's default initialization unchanged.
            self.layer = id_conv(in_dim, 0)
        self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):
        """Return ``(projected spatial features, projected cls or None)``."""
        # Only the normalized spatial tensor is consumed; the normalized cls
        # returned by self.norm is discarded and the raw ``cls`` is projected.
        # NOTE(review): this may be unintentional, but switching to the
        # normalized cls would change the outputs of already-trained weights;
        # confirm before altering.
        norm_spatial, _unused_norm_cls = self.norm(spatial, cls)
        aligned_cls = None if cls is None else self.cls_layer(cls)
        return self.layer(norm_spatial), aligned_cls
class IdLinearAligner(torch.nn.Module):
    """1x1 conv aligner that starts out as an exact identity mapping.

    Args:
        in_dim: Channel count of the incoming features.
        out_dim: Channel count of the produced features; must equal ``in_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        assert self.out_dim == self.in_dim
        # strength 1.0 -> identity weights and zero bias at initialization.
        self.layer = id_conv(in_dim, 1.0)

    def forward(self, spatial, cls):
        """Apply the conv to ``spatial``; ``cls`` is returned untouched."""
        projected = self.layer(spatial)
        return projected, cls
class FrequencyAvg(torch.nn.Module):
    """Average the spatial tensor over dim 2, keeping that axis with size 1.

    Judging by the class name, dim 2 is presumably the frequency axis of an
    audio feature map; confirm against the caller. ``cls`` is passed through.
    """

    def forward(self, spatial, cls):
        """Return ``(spatial averaged over dim 2, cls)``."""
        pooled = torch.mean(spatial, dim=2, keepdim=True)
        return pooled, cls
class LearnedTimePool(torch.nn.Module):
    """Channel-normalize, then downsample the last axis by a factor of ``width``.

    Args:
        dim: Channel count (unchanged by this layer).
        width: Pooling window and stride along the last axis.
        maxpool: If true, use a ``width`` x ``width`` same-padded conv followed
            by max pooling over ``(1, width)`` windows; otherwise use a single
            strided ``(1, width)`` conv.
    """

    def __init__(self, dim, width, maxpool):
        super().__init__()
        self.dim = dim
        self.width = width
        self.norm = ChannelNorm(dim)
        window = (1, width)
        if not maxpool:
            self.layer = torch.nn.Conv2d(dim, dim, kernel_size=window, stride=window)
        else:
            mixing_conv = torch.nn.Conv2d(dim, dim, kernel_size=width, stride=1, padding="same")
            pool = torch.nn.MaxPool2d(kernel_size=window, stride=window)
            self.layer = torch.nn.Sequential(mixing_conv, pool)

    def forward(self, spatial, cls):
        """Return ``(pooled normalized spatial features, normalized cls)``."""
        norm_spatial, norm_cls = self.norm(spatial, cls)
        pooled = self.layer(norm_spatial)
        return pooled, norm_cls
class LearnedTimePool2(torch.nn.Module):
    """Downsample the last axis by ``width`` while mapping channels, without a norm.

    Args:
        in_dim: Channel count of the incoming features.
        out_dim: Channel count of the produced features.
        width: Pooling window and stride along the last axis.
        maxpool: If true, use a ``width`` x ``width`` same-padded conv followed
            by max pooling over ``(1, width)`` windows; otherwise use a single
            strided ``(1, width)`` conv.
        use_cls_layer: If true, ``cls`` is projected by a linear layer from
            ``in_dim`` to ``out_dim``; otherwise it is returned unchanged.
    """

    def __init__(self, in_dim, out_dim, width, maxpool, use_cls_layer):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.width = width
        window = (1, width)
        if not maxpool:
            self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=window, stride=window)
        else:
            mixing_conv = torch.nn.Conv2d(in_dim, out_dim, kernel_size=width, stride=1, padding="same")
            pool = torch.nn.MaxPool2d(kernel_size=window, stride=window)
            self.layer = torch.nn.Sequential(mixing_conv, pool)
        self.use_cls_layer = use_cls_layer
        # The cls projection only exists when requested.
        if use_cls_layer:
            self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):
        """Return ``(pooled spatial features, projected/unchanged cls or None)``."""
        if cls is None:
            aligned_cls = None
        elif self.use_cls_layer:
            aligned_cls = self.cls_layer(cls)
        else:
            aligned_cls = cls
        return self.layer(spatial), aligned_cls
class Sequential2(torch.nn.Module):
    """Sequential container for modules that take and return two values.

    Each module is called with the unpacked output of the previous one, so
    every stage must accept two positional arguments and return a pair.

    Args:
        *modules: Stages to apply, in order.
    """

    def __init__(self, *modules):
        super().__init__()
        self.mod_list = ModuleList(modules)

    def forward(self, x, y):
        """Thread ``(x, y)`` through every stage and return the final pair."""
        state = (x, y)
        for stage in self.mod_list:
            state = stage(*state)
        return state
class ProgressiveGrowing(torch.nn.Module):
    """Pipeline of two-input stages that are activated progressively.

    Only the first ``phase`` stages are applied in ``forward`` and exposed by
    ``parameters``. ``phase`` starts at 1, is stored as a registered buffer,
    and is advanced by ``maybe_change_phase`` from the global step.

    Args:
        stages: Modules taking ``(spatial, cls)`` and returning a pair.
        phase_lengths: Number of steps spent in each phase before the next
            stage is enabled; must contain one entry fewer than ``stages``.
    """

    def __init__(self, stages, phase_lengths):
        super().__init__()
        self.stages = torch.nn.ModuleList(stages)
        # Plain tensor attributes (not buffers): they are not part of the
        # state dict and are not moved by ``.to(device)``.
        self.phase_lengths = torch.tensor(phase_lengths)
        assert len(self.phase_lengths) + 1 == len(self.stages)
        # Cumulative step counts at which each subsequent phase begins.
        self.phase_boundaries = self.phase_lengths.cumsum(0)
        self.register_buffer('phase', torch.tensor([1]))

    def maybe_change_phase(self, global_step):
        """Update ``phase`` for ``global_step``.

        Returns:
            True if the phase changed, False otherwise.
        """
        # Number of boundaries already reached, plus the always-on first phase.
        needed_phase = (global_step >= self.phase_boundaries).to(torch.int64).sum().item() + 1
        if needed_phase == self.phase.item():
            return False
        print(f"Changing aligner phase to {needed_phase}")
        # In-place fill keeps the buffer on its device without building a
        # temporary tensor and transferring it.
        self.phase.fill_(needed_phase)
        return True

    def parameters(self, recurse: bool = True):
        """Iterate over the parameters of the currently active stages only."""
        phase = self.phase.item()
        print(f"Progressive Growing at stage {phase}")
        # Collected eagerly so the result reflects the phase at call time.
        active_params = [
            param
            for stage in self.stages[:phase]
            for param in stage.parameters(recurse)
        ]
        return iter(active_params)

    def forward(self, spatial, cls):
        """Run ``(spatial, cls)`` through the active stages, in order."""
        # Chain the stages directly instead of wrapping them in a new
        # container module on every call.
        results = (spatial, cls)
        for stage in self.stages[:self.phase.item()]:
            results = stage(*results)
        return results
class Identity2(torch.nn.Module):
    """Two-input identity: returns both arguments unchanged as a tuple."""

    def forward(self, x, y):
        """Return ``(x, y)`` without modification."""
        return (x, y)
class SelfAttentionAligner(torch.nn.Module):
    """Apply one ``Block`` over the flattened spatial positions (plus ``cls``).

    The channel axis is zero-padded at the front so that its size is
    divisible by the fixed head count (6); the padding is removed again from
    the outputs.

    Args:
        dim: Channel count of the incoming (and outgoing) features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.num_heads = 6
        # Smallest number of extra channels making dim divisible by num_heads.
        remainder = dim % self.num_heads
        self.padding = self.num_heads - remainder if remainder != 0 else 0
        self.block = Block(
            dim + self.padding,
            num_heads=self.num_heads,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-4))

    def forward(self, spatial, cls):
        """Return ``(aligned spatial features, aligned cls or None)``.

        ``spatial`` is a 4-D tensor with channels at dim 1; ``cls``, when
        given, must be 2-D and is prepended as the first token.
        """
        pad = self.padding
        has_cls = cls is not None

        # F.pad lists pads from the last axis backwards, so the final pair
        # (pad, 0) prepends ``pad`` zero channels on dim 1.
        padded_feats = F.pad(spatial, [0, 0, 0, 0, pad, 0])
        B, C, H, W = padded_feats.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = padded_feats.flatten(2).transpose(1, 2)

        if has_cls:
            assert len(cls.shape) == 2
            cls_token = F.pad(cls, [pad, 0]).unsqueeze(1)
            tokens = torch.cat([cls_token, tokens], dim=1)

        # The attention map and qkv returned by the block are not used here.
        aligned_feat, _attn, _qkv = self.block(tokens, return_qkv=True)

        if has_cls:
            # Token 0 is the cls token; drop its padded leading channels.
            aligned_cls = aligned_feat[:, 0, :][:, pad:]
            spatial_tokens = aligned_feat[:, 1:, :]
        else:
            aligned_cls = None
            spatial_tokens = aligned_feat

        # Back to (B, C, H, W), then strip the padded leading channels.
        aligned_spatial = spatial_tokens.reshape(B, H, W, self.dim + pad).permute(0, 3, 1, 2)
        return aligned_spatial[:, pad:, :, :], aligned_cls
def get_aligner(aligner_type, in_dim, out_dim, **kwargs):
    """Construct the aligner module registered under ``aligner_type``.

    Args:
        aligner_type: One of the names listed in the table below, or ``None``
            for a pass-through ``Identity2``.
        in_dim: Channel count of the incoming features.
        out_dim: Channel count of the produced features.
        **kwargs: Accepted for backward compatibility; currently unused.

    Returns:
        A module taking ``(spatial, cls)`` and returning a pair.

    Raises:
        ValueError: If ``aligner_type`` is not a known name.
    """
    if aligner_type is None:
        return Identity2()

    # Builders are lambdas so that only the requested aligner is constructed
    # (module construction allocates and randomly initializes weights).
    builders = {
        "image_linear": lambda: LinearAligner(in_dim, out_dim),
        "image_idlinear": lambda: IdLinearAligner(in_dim, out_dim),
        "image_linear_no_norm": lambda: LinearAligner(in_dim, out_dim, use_norm=False),
        "image_id": lambda: Identity2(),
        "image_norm": lambda: ChannelNorm(in_dim),
        "audio_linear": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
        ),
        "audio_sa": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            SelfAttentionAligner(out_dim),
        ),
        "audio_sa_sa": lambda: Sequential2(
            FrequencyAvg(),
            LinearAligner(in_dim, out_dim),
            SelfAttentionAligner(out_dim),
            SelfAttentionAligner(out_dim),
        ),
        "audio_3_3_pool": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
        ),
        "audio_sa_3_3_pool": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
            SelfAttentionAligner(out_dim),
        ),
        "audio_sa_3_3_pool_2": lambda: Sequential2(
            FrequencyAvg(),
            ChannelNorm(in_dim),
            LearnedTimePool2(in_dim, out_dim, 3, False, True),
            LearnedTimePool2(out_dim, out_dim, 3, False, False),
            SelfAttentionAligner(out_dim),
        ),
    }

    build = builders.get(aligner_type)
    if build is None:
        # No "prog" aligner types are registered; they fall through here like
        # any other unknown name instead of first requiring a ``phase_length``
        # keyword that was never used.
        raise ValueError(f"Unknown aligner type {aligner_type}")
    return build()