Spaces:
Running
Running
from functools import partial | |
import torch | |
import torch.nn.functional as F | |
from torch.nn import ModuleList | |
from DenseAV.denseav.featurizers.DINO import Block | |
class ChannelNorm(torch.nn.Module):
    """LayerNorm applied over the channel axis of a channels-first map.

    The same ``LayerNorm`` (eps=1e-4) is also used for an optional ``cls``
    vector, which is normalised over its last axis.
    """

    def __init__(self, dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-4)

    def forward_spatial(self, x):
        """Normalise a 4-D ``x`` across axis 1 by moving it last and back."""
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.norm(channels_last)
        return normed.permute(0, 3, 1, 2)

    def forward(self, x, cls):
        """Return ``(normalised x, normalised cls or None)``."""
        spatial_out = self.forward_spatial(x)
        cls_out = self.forward_cls(cls)
        return spatial_out, cls_out

    def forward_cls(self, cls):
        """Normalise ``cls`` over its last axis; ``None`` passes through."""
        return None if cls is None else self.norm(cls)
def id_conv(dim, strength=.9):
    """Build a 1x1 ``dim -> dim`` Conv2d blended toward the identity map.

    Args:
        dim: number of input and output channels.
        strength: blend factor. The final weight is
            ``strength * I + (1 - strength) * W0`` and the final bias is
            ``(1 - strength) * b0``, where ``W0``/``b0`` are PyTorch's default
            initial parameters. ``1.0`` gives an exact identity (weight I,
            bias 0); ``0.0`` leaves the default initialisation unchanged.

    Returns:
        The configured ``torch.nn.Conv2d`` (kernel 1, ``padding="same"``).
    """
    conv = torch.nn.Conv2d(dim, dim, 1, padding="same")
    identity = torch.eye(dim, device=conv.weight.device, dtype=conv.weight.dtype)
    # Blend in place under no_grad instead of assigning new Parameter objects
    # to `.data`. The registered parameters stay the same objects and the
    # values are identical.
    with torch.no_grad():
        conv.weight.mul_(1 - strength).add_(identity[:, :, None, None] * strength)
        conv.bias.mul_(1 - strength)
    return conv
class LinearAligner(torch.nn.Module):
    """Optionally channel-normalise a feature map, then project it with a 1x1 conv.

    An optional ``cls`` vector is projected by a separate ``Linear`` layer.

    Args:
        in_dim: incoming channel count.
        out_dim: projected channel count.
        use_norm: apply ``ChannelNorm`` first (otherwise a pass-through).
    """

    def __init__(self, in_dim, out_dim, use_norm=True):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.norm = ChannelNorm(in_dim) if use_norm else Identity2()
        # For equal widths, id_conv(..., 0) keeps the conv's default init
        # (strength 0 adds no identity component).
        self.layer = (
            id_conv(in_dim, 0)
            if in_dim == out_dim
            else torch.nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1)
        )
        self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):
        """Return ``(projected spatial, projected cls or None)``."""
        norm_spatial, norm_cls = self.norm(spatial, cls)
        # NOTE(review): cls_layer receives the raw `cls`, not `norm_cls`
        # (norm_cls is computed but unused). Confirm this is intended; changing
        # it would alter the behaviour of already-trained weights.
        aligned_cls = None if cls is None else self.cls_layer(cls)
        return self.layer(norm_spatial), aligned_cls
class IdLinearAligner(torch.nn.Module):
    """Width-preserving 1x1 conv initialised to an exact identity; cls passes through.

    ``in_dim`` and ``out_dim`` must be equal (checked with ``assert``).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        assert self.out_dim == self.in_dim
        # strength 1.0 -> weight is the identity matrix and bias is zero.
        self.layer = id_conv(in_dim, 1.0)

    def forward(self, spatial, cls):
        """Apply the conv to ``spatial`` and return ``cls`` untouched."""
        aligned = self.layer(spatial)
        return aligned, cls
class FrequencyAvg(torch.nn.Module):
    """Average ``spatial`` over axis 2, keeping it as a size-1 axis; cls passes through.

    Axis 2 is presumably the frequency axis of an audio feature map (per the
    class name). TODO confirm against callers.
    """

    def forward(self, spatial, cls):
        pooled = torch.mean(spatial, dim=2, keepdim=True)
        return pooled, cls
class LearnedTimePool(torch.nn.Module):
    """Channel-normalise, then downsample axis 3 by ``width`` with learned weights.

    Args:
        dim: channel count (unchanged by this module).
        width: pooling window and stride along axis 3 (presumably time; TODO confirm).
        maxpool: if true, use a ``width x width`` "same" conv followed by a
            ``(1, width)`` max-pool; otherwise use a single strided
            ``(1, width)`` conv.
    """

    def __init__(self, dim, width, maxpool):
        super().__init__()
        self.dim = dim
        self.width = width
        self.norm = ChannelNorm(dim)
        window = (1, width)
        if not maxpool:
            self.layer = torch.nn.Conv2d(dim, dim, kernel_size=window, stride=window)
        else:
            conv = torch.nn.Conv2d(dim, dim, kernel_size=width, stride=1, padding="same")
            pool = torch.nn.MaxPool2d(kernel_size=window, stride=window)
            self.layer = torch.nn.Sequential(conv, pool)

    def forward(self, spatial, cls):
        """Return ``(pooled normalised spatial, normalised cls or None)``."""
        normed_spatial, normed_cls = self.norm(spatial, cls)
        return self.layer(normed_spatial), normed_cls
class LearnedTimePool2(torch.nn.Module):
    """Downsample axis 3 by ``width`` while mapping ``in_dim`` -> ``out_dim`` channels.

    Unlike ``LearnedTimePool`` there is no normalisation step.

    Args:
        in_dim: incoming channel count.
        out_dim: outgoing channel count.
        width: pooling window and stride along axis 3.
        maxpool: if true, use a ``width x width`` "same" conv then a
            ``(1, width)`` max-pool; otherwise use a strided ``(1, width)`` conv.
        use_cls_layer: project ``cls`` with a ``Linear(in_dim, out_dim)``;
            otherwise ``cls`` is returned unchanged.
    """

    def __init__(self, in_dim, out_dim, width, maxpool, use_cls_layer):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.width = width
        window = (1, width)
        if maxpool:
            conv = torch.nn.Conv2d(in_dim, out_dim, kernel_size=width, stride=1, padding="same")
            self.layer = torch.nn.Sequential(
                conv, torch.nn.MaxPool2d(kernel_size=window, stride=window))
        else:
            self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=window, stride=window)
        self.use_cls_layer = use_cls_layer
        if use_cls_layer:
            self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):
        """Return ``(pooled spatial, projected/unchanged cls or None)``."""
        if cls is None:
            aligned_cls = None
        elif self.use_cls_layer:
            aligned_cls = self.cls_layer(cls)
        else:
            aligned_cls = cls
        return self.layer(spatial), aligned_cls
class Sequential2(torch.nn.Module):
    """Chain modules that take two positional inputs and return a pair.

    The first module receives ``(x, y)``. Each later module receives the
    previous module's output unpacked as positional arguments. With no modules
    the tuple ``(x, y)`` is returned.
    """

    def __init__(self, *modules):
        super().__init__()
        self.mod_list = ModuleList(modules)

    def forward(self, x, y):
        carried = (x, y)
        for stage in iter(self.mod_list):
            carried = stage(*carried)
        return carried
class ProgressiveGrowing(torch.nn.Module):
    """Apply a growing prefix of ``stages``; the prefix length is the current phase.

    Phase 1 runs only ``stages[0]``. Each time ``global_step`` reaches the next
    running total of ``phase_lengths``, the phase advances by one and one more
    stage is enabled.

    Args:
        stages: modules with a ``(spatial, cls) -> (spatial, cls)`` interface.
        phase_lengths: step counts for phases 1..len(stages)-1; must have
            exactly one fewer entry than ``stages`` (checked with ``assert``).
    """

    def __init__(self, stages, phase_lengths):
        super().__init__()
        self.stages = torch.nn.ModuleList(stages)
        # Plain tensor attributes (not buffers): they stay where they were
        # created and are not part of state_dict.
        self.phase_lengths = torch.tensor(phase_lengths)
        assert len(self.phase_lengths) + 1 == len(self.stages)
        # Step at which each later phase begins, e.g. [10, 20] -> [10, 30].
        self.phase_boundaries = self.phase_lengths.cumsum(0)
        # Current 1-based phase, registered as a buffer so it is checkpointed.
        self.register_buffer('phase', torch.tensor([1]))

    def maybe_change_phase(self, global_step):
        """Set the phase implied by ``global_step``.

        Returns:
            True if the phase changed (a message is printed), else False.
        """
        needed_phase = (global_step >= self.phase_boundaries).to(torch.int64).sum().item() + 1
        if needed_phase != self.phase.item():
            print(f"Changing aligner phase to {needed_phase}")
            self.phase.copy_(torch.tensor([needed_phase]).to(self.phase.device))
            return True
        else:
            return False

    def parameters(self, recurse: bool = True):
        """Return an iterator over the parameters of the active stages only.

        Stages beyond the current phase are excluded, so an optimizer built
        from this call does not see them. Prints the current phase.
        """
        phase = self.phase.item()
        used_stages = self.stages[:phase]
        print(f"Progressive Growing at stage {phase}")
        all_params = []
        for stage in used_stages:
            all_params.extend(stage.parameters(recurse))
        return iter(all_params)

    def forward(self, spatial, cls):
        """Run ``(spatial, cls)`` through the active stages in order."""
        # Iterate the active prefix directly instead of building a throwaway
        # wrapper module (and a fresh ModuleList) on every forward call.
        results = (spatial, cls)
        for stage in self.stages[:self.phase.item()]:
            results = stage(*results)
        return results
class Identity2(torch.nn.Module):
    """Two-argument no-op: returns its inputs unchanged as the tuple ``(x, y)``."""

    def forward(self, x, y):
        pair = (x, y)
        return pair
class SelfAttentionAligner(torch.nn.Module):
    """Refine a feature map (and optional cls vector) with one transformer ``Block``.

    Each spatial position becomes a token; a cls vector, if given, is
    prepended as an extra token. The block output is reshaped back to a map.

    Args:
        dim: channel count of the incoming map and cls vector. The reshape in
            ``forward`` requires ``spatial`` to have exactly ``dim`` channels.
        num_heads: attention heads for ``Block``. Defaults to 6, the value that
            was previously hard-coded. Channels are zero-padded at the front up
            to the next multiple of ``num_heads``; the padding is stripped from
            the outputs.
    """

    def __init__(self, dim, num_heads=6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # Smallest non-negative pad making dim + padding divisible by num_heads,
        # presumably required by Block's head split. TODO confirm in Block.
        self.padding = (-dim) % num_heads
        self.block = Block(
            dim + self.padding,
            num_heads=self.num_heads,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-4))

    def forward(self, spatial, cls):
        """Return ``(refined spatial map, refined cls or None)``.

        ``cls`` must be 2-D when given (checked with ``assert``).
        """
        # F.pad lists pads last-axis-first: W (0, 0), H (0, 0), C (padding, 0).
        padded_feats = F.pad(spatial, [0, 0, 0, 0, self.padding, 0])
        B, C, H, W = padded_feats.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = padded_feats.reshape(B, C, H * W).permute(0, 2, 1)
        if cls is not None:
            assert len(cls.shape) == 2
            padded_cls = F.pad(cls, [self.padding, 0])
            tokens = torch.cat([padded_cls.unsqueeze(1), tokens], dim=1)

        # With return_qkv=True, Block returns a 3-tuple; only the features are used.
        aligned_feat, _, _ = self.block(tokens, return_qkv=True)

        if cls is not None:
            aligned_cls = aligned_feat[:, 0, :]
            aligned_spatial = aligned_feat[:, 1:, :]
        else:
            aligned_cls = None
            aligned_spatial = aligned_feat

        # (B, H*W, C) -> (B, C, H, W), then drop the zero-padded leading channels.
        aligned_spatial = aligned_spatial.reshape(B, H, W, self.dim + self.padding).permute(0, 3, 1, 2)
        aligned_spatial = aligned_spatial[:, self.padding:, :, :]
        if aligned_cls is not None:
            aligned_cls = aligned_cls[:, self.padding:]
        return aligned_spatial, aligned_cls
def get_aligner(aligner_type, in_dim, out_dim, **kwargs):
    """Construct the aligner module registered under ``aligner_type``.

    Args:
        aligner_type: registry key (see ``factories`` below), or ``None`` for a
            no-op ``Identity2``.
        in_dim: channel count of the incoming features.
        out_dim: channel count produced by the aligner.
        **kwargs: only read for keys containing ``"prog"``, which must supply
            ``phase_length`` (a missing entry raises ``KeyError``).

    Returns:
        A freshly constructed module whose ``forward`` takes ``(spatial, cls)``.

    Raises:
        ValueError: if ``aligner_type`` matches no registered key.
    """
    if aligner_type is None:
        return Identity2()
    if "prog" in aligner_type:
        # No "prog" factory is registered below; this lookup only makes a
        # missing phase_length fail with KeyError.
        _ = kwargs["phase_length"]

    # Factories are lazy so only the selected aligner is built. Argument order
    # inside each factory fixes the order in which sub-modules are created.
    factories = {
        "image_linear": lambda: LinearAligner(in_dim, out_dim),
        "image_idlinear": lambda: IdLinearAligner(in_dim, out_dim),
        "image_linear_no_norm": lambda: LinearAligner(in_dim, out_dim, use_norm=False),
        "image_id": lambda: Identity2(),
        "image_norm": lambda: ChannelNorm(in_dim),
        "audio_linear": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
        ),
        "audio_sa": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            SelfAttentionAligner(out_dim),
        ),
        "audio_sa_sa": lambda: Sequential2(
            FrequencyAvg(),
            LinearAligner(in_dim, out_dim),
            SelfAttentionAligner(out_dim),
            SelfAttentionAligner(out_dim),
        ),
        "audio_3_3_pool": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
        ),
        "audio_sa_3_3_pool": lambda: Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
            SelfAttentionAligner(out_dim),
        ),
        "audio_sa_3_3_pool_2": lambda: Sequential2(
            FrequencyAvg(),
            ChannelNorm(in_dim),
            LearnedTimePool2(in_dim, out_dim, 3, False, True),
            LearnedTimePool2(out_dim, out_dim, 3, False, False),
            SelfAttentionAligner(out_dim),
        ),
    }
    # Match by equality (not hashing) so any aligner_type compares exactly as
    # an explicit if/elif chain would.
    for name, factory in factories.items():
        if aligner_type == name:
            return factory()
    raise ValueError(f"Unknown aligner type {aligner_type}")