Spaces:
Running
on
Zero
Running
on
Zero
import re | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
class MotionEmbedding(nn.Module):
    """Learnable per-frame embedding that is added to a 3-D input tensor.

    ``embed`` has shape ``(wh, max_seq_length, embed_dim)``: ``wh`` separate
    tables (1 by default) of up to ``max_seq_length`` frames each. When the
    input's leading (batch) size differs from ``wh``, the tables are tiled
    ``x.shape[0] // wh`` times along that axis, so batch item ``i`` receives
    table ``i % wh`` (the batch size should therefore be a multiple of ``wh``).

    ``scale`` multiplies the embedding before it is added. ``trained_length``
    records how many frames the embedding was trained on; when it is positive
    and the input has a different number of frames, the first
    ``trained_length`` frames are linearly resampled to the input length.
    """

    def __init__(self, embed_dim: int = None, max_seq_length: int = 32, wh: int = 1):
        super().__init__()
        self.embed = nn.Parameter(torch.zeros(wh, max_seq_length, embed_dim))
        self.scale = 1.0
        # -1 means "no trained length recorded": frames are used as-is.
        self.trained_length = -1

    def set_scale(self, scale: float):
        """Set the multiplier applied to the embedding before it is added."""
        self.scale = scale

    def set_lengths(self, trained_length: int):
        """Record the number of frames the embedding was trained on.

        Raises:
            ValueError: if ``trained_length`` is not in ``[1, max_seq_length]``.
        """
        if trained_length > self.embed.shape[1] or trained_length <= 0:
            raise ValueError("Trained length is out of bounds")
        self.trained_length = trained_length

    def _aligned_embeddings(self, seq_length: int) -> torch.Tensor:
        """Return the embedding tables with exactly ``seq_length`` frames.

        Raises:
            ValueError: if fewer than ``seq_length`` frames are available and
                no resampling applies.
        """
        if self.trained_length > 0 and seq_length != self.trained_length:
            # Resample the frames that were actually trained. Slicing to
            # seq_length before interpolating to seq_length (the previous
            # behaviour) made F.interpolate an identity op, so resampling
            # never happened.
            embeddings = self.embed[:, :self.trained_length]
            # F.interpolate resamples the last axis: go to (wh, dim, frames).
            embeddings = embeddings.permute(0, 2, 1)
            embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
            embeddings = embeddings.permute(0, 2, 1)
        else:
            embeddings = self.embed[:, :seq_length]
        if embeddings.shape[1] != seq_length:
            raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}")
        return embeddings

    def _add_to(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Return ``x + embeddings * scale``, tiling the tables over axis 0."""
        if x.shape[0] != embeddings.shape[0]:
            embeddings = embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1)
        return x + embeddings * self.scale

    def forward(self, x):
        """Add the (possibly resampled) embedding to ``x``.

        Args:
            x: 3-D tensor whose second axis is the frame axis; presumably
                ``(batch, frames, embed_dim)`` — confirm against callers.
        """
        _, seq_length, _ = x.shape
        return self._add_to(x, self._aligned_embeddings(seq_length))

    def forward_average(self, x):
        """Like ``forward``, but subtract each table's mean over frames first."""
        _, seq_length, _ = x.shape
        embeddings = self._aligned_embeddings(seq_length)
        embeddings = embeddings - embeddings.mean(dim=1, keepdim=True)
        return self._add_to(x, embeddings)

    def forward_frameSubtraction(self, x):
        """Like ``forward``, but frames 1.. are replaced by frame-to-frame differences.

        Frame 0 keeps its own value.
        """
        _, seq_length, _ = x.shape
        embeddings = self._aligned_embeddings(seq_length)
        differences = embeddings[:, 1:] - embeddings[:, :-1]
        # NOTE(review): the detached copy means frame 0 receives no gradient
        # through this path; only the differences do — confirm intended.
        embeddings = embeddings.clone().detach()
        embeddings[:, 1:] = differences
        return self._add_to(x, embeddings)
class MotionEmbeddingSpatial(nn.Module):
    """Learnable per-frame embedding with one table per spatial position.

    ``embed`` has shape ``(h * w, max_seq_length, embed_dim)``. When the
    input's leading size differs from ``h * w``, the tables are tiled
    ``x.shape[0] // (h * w)`` times along axis 0 (batch item ``i`` receives
    table ``i % (h * w)``), so that size should be a multiple of ``h * w``.

    ``scale`` multiplies the embedding before it is added. When
    ``trained_length`` is positive and the input has a different number of
    frames, the first ``trained_length`` frames are linearly resampled.
    """

    def __init__(self, h: int = None, w: int = None, embed_dim: int = None, max_seq_length: int = 32):
        super().__init__()
        self.embed = nn.Parameter(torch.zeros(h*w, max_seq_length, embed_dim))
        self.scale = 1.0
        # -1 means "no trained length recorded": frames are used as-is.
        self.trained_length = -1

    def set_scale(self, scale: float):
        """Set the multiplier applied to the embedding before it is added."""
        self.scale = scale

    def set_lengths(self, trained_length: int):
        """Record the number of frames the embedding was trained on.

        Raises:
            ValueError: if ``trained_length`` is not in ``[1, max_seq_length]``.
        """
        if trained_length > self.embed.shape[1] or trained_length <= 0:
            raise ValueError("Trained length is out of bounds")
        self.trained_length = trained_length

    def forward(self, x):
        """Add the (possibly resampled) embedding to ``x``.

        Args:
            x: 3-D tensor whose second axis is the frame axis; presumably
                ``(batch * h * w, frames, embed_dim)`` — confirm against callers.

        Raises:
            ValueError: if fewer than ``frames`` embedding frames are available
                and no resampling applies.
        """
        _, seq_length, _ = x.shape
        if self.trained_length > 0 and seq_length != self.trained_length:
            # Resample the trained frames; slicing to seq_length first (the
            # previous behaviour) made the interpolation an identity op.
            embeddings = self.embed[:, :self.trained_length].permute(0, 2, 1)
            embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
            embeddings = embeddings.permute(0, 2, 1)
        else:
            embeddings = self.embed[:, :seq_length]
        if embeddings.shape[1] != seq_length:
            raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}")
        if x.shape[0] != embeddings.shape[0]:
            embeddings = embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1)
        return x + embeddings * self.scale
def inject_motion_embeddings(model, combinations=None, config=None):
    """Attach learnable motion embeddings to ``model`` and freeze everything else.

    A module is selected when its name contains ``.temp_attentions.``, ends in
    ``transformer_blocks.<n>``, and the list ``[name prefix before the first
    '_', module.norm1.normalized_shape[0]]`` is in ``combinations``. For each
    selected block this adds:

    * ``<block>.pos_embed`` — a ``MotionEmbedding`` with a single table of up
      to 32 frames;
    * ``<block>.attn1.vSpatial`` and ``<block>.attn2.vSpatial`` — a
      ``MotionEmbedding`` with ``h * w`` tables, where ``(h, w)`` is
      ``ceil((height, width) / 32)`` for up-block 1 / down-block 2 and
      ``ceil((height, width) / 16)`` for up-block 2 / down-block 1.

    Args:
        model: module exposing ``named_modules``, ``named_parameters``,
            ``get_submodule``, ``dtype`` and ``device``; modified in place.
        combinations: container of ``[prefix, embed_dim]`` *lists* (membership
            is tested with lists, and a tuple never equals a list).
        config: object providing ``config.dataset.height`` and
            ``config.dataset.width``.

    Returns:
        ``(parameters_list, inject_layers)``. ``parameters_list`` holds every
        parameter whose name contains ``pos_embed`` or ``vSpatial``; these are
        set to ``requires_grad=True`` and all other parameters to ``False``.
        ``inject_layers`` holds, in injection order, the names of injected
        modules whose parent prefix and embedding width are in ``combinations``.

    Raises:
        ValueError: if a selected temporal attention is not in up/down block
            1 or 2, for which no spatial size is defined.
    """
    spatial_name = 'vSpatial'
    # support for 32 frames
    max_seq_length = 32

    spatial_shape = np.array([config.dataset.height, config.dataset.width])
    shape32 = np.ceil(spatial_shape / 32).astype(int)
    shape16 = np.ceil(spatial_shape / 16).astype(int)
    # (place in net, block index) -> (h, w) grid of that block's spatial embedding.
    grid_by_block = {
        ('up', '1'): shape32,
        ('up', '2'): shape16,
        ('down', '1'): shape16,
        ('down', '2'): shape32,
    }

    # 1) One temporal `pos_embed` per selected temporal transformer block.
    replacement_dict = {}
    for name, module in model.named_modules():
        if '.temp_attentions.' not in name or not re.search(r'transformer_blocks\.\d+$', name):
            continue
        embed_dim = module.norm1.normalized_shape[0]
        if [name.split('_')[0], embed_dim] not in combinations:
            continue
        replacement_dict[f'{name}.pos_embed'] = MotionEmbedding(
            max_seq_length=max_seq_length, embed_dim=embed_dim
        ).to(dtype=model.dtype, device=model.device)

    # 2) One spatial embedding on attn1 and attn2 of each selected block,
    #    with the same width as that block's pos_embed.
    spatial_dict = {}
    for attn_name in ('attn1', 'attn2'):
        for pos_key, pos_module in replacement_dict.items():
            temp_attn = pos_key.replace('pos_embed', attn_name)
            match = re.search(r'(\d+)\.temp_attentions', temp_attn)
            block_key = (temp_attn.split('_')[0], match.group(1) if match else None)
            if block_key not in grid_by_block:
                raise ValueError(
                    f"No spatial size defined for motion embedding at '{temp_attn}'; "
                    "only up/down blocks 1 and 2 are supported"
                )
            h, w = grid_by_block[block_key]
            spatial_dict[f'{temp_attn}.{spatial_name}'] = MotionEmbedding(
                wh=int(h * w), embed_dim=pos_module.embed.shape[2]
            ).to(dtype=model.dtype, device=model.device)
    replacement_dict.update(spatial_dict)

    # 3) Install the new modules (keys are unique, so no de-duplication needed).
    inject_layers = []
    for name, new_module in replacement_dict.items():
        parent_name, _, module_name = name.rpartition('.')
        parent_module = model.get_submodule(parent_name) if parent_name else model
        if [parent_name.split('_')[0], new_module.embed.shape[-1]] in combinations:
            inject_layers.append(name)
        setattr(parent_module, module_name, new_module)

    # 4) Train only the injected embeddings.
    parameters_list = []
    for name, para in model.named_parameters():
        trainable = 'pos_embed' in name or spatial_name in name
        para.requires_grad = trainable
        if trainable:
            parameters_list.append(para)
    return parameters_list, inject_layers
def save_motion_embeddings(model, file_path):
    """Write every motion-embedding table found in ``model`` to ``file_path``.

    The saved object is a dict mapping the qualified name (as produced by
    ``model.named_modules()``) of each ``MotionEmbedding`` or
    ``MotionEmbeddingSpatial`` module to that module's ``embed`` parameter.
    It is serialized with ``torch.save``.
    """
    embedding_types = (MotionEmbedding, MotionEmbeddingSpatial)
    collected = {}
    for qualified_name, submodule in model.named_modules():
        if isinstance(submodule, embedding_types):
            collected[qualified_name] = submodule.embed
    torch.save(collected, file_path)
def load_motion_embeddings(model, saved_embeddings):
    """Install saved embedding tables into ``model`` as ``MotionEmbedding`` modules.

    For every ``name -> tensor`` entry, a new ``MotionEmbedding`` sized from the
    tensor is created. Its ``embed`` parameter is set to the tensor converted
    to ``model.dtype`` / ``model.device``, and the module is assigned at
    ``name``, replacing whatever was there. The modules are freshly
    constructed, so ``scale`` and ``trained_length`` have their constructor
    defaults.

    Args:
        model: module exposing ``get_submodule``, ``dtype`` and ``device``;
            modified in place.
        saved_embeddings: mapping of qualified module name to a tensor of shape
            ``(wh, max_seq_length, embed_dim)``, such as the dict written by
            ``save_motion_embeddings``.
    """
    for key, embedding in saved_embeddings.items():
        parent_name, _, module_name = key.rpartition('.')
        # get_submodule walks only this path, instead of rebuilding a dict of
        # every module in the model for each key.
        parent_module = model.get_submodule(parent_name) if parent_name else model
        new_module = MotionEmbedding(
            wh=embedding.shape[0],
            embed_dim=embedding.shape[-1],
            max_seq_length=embedding.shape[-2],
        )
        new_module.embed = nn.Parameter(embedding.to(dtype=model.dtype, device=model.device))
        setattr(parent_module, module_name, new_module)
def set_motion_embedding_scale(model, scale_value):
    """Set ``scale`` on every motion-embedding module inside ``model``.

    Both ``MotionEmbedding`` and ``MotionEmbeddingSpatial`` carry a ``scale``
    multiplier. Both are updated here, matching the module types that
    ``save_motion_embeddings`` collects.
    """
    for module in model.modules():
        if isinstance(module, (MotionEmbedding, MotionEmbeddingSpatial)):
            module.scale = scale_value
def set_motion_embedding_length(model, trained_length):
    """Set ``trained_length`` on every motion-embedding module inside ``model``.

    Both ``MotionEmbedding`` and ``MotionEmbeddingSpatial`` are updated. The
    value is assigned directly: unlike ``set_lengths``, it is not validated
    against the table length.
    """
    for module in model.modules():
        if isinstance(module, (MotionEmbedding, MotionEmbeddingSpatial)):
            module.trained_length = trained_length