# MotionInversion/models/unet/motion_embeddings.py
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MotionEmbedding(nn.Module):

    def __init__(self, embed_dim: int = None, max_seq_length: int = 32, wh: int = 1):
        super().__init__()
        self.embed = nn.Parameter(torch.zeros(wh, max_seq_length, embed_dim))
        # print('register spatial motion embedding with', wh)
        self.scale = 1.0
        self.trained_length = -1

    def set_scale(self, scale: float):
        self.scale = scale

    def set_lengths(self, trained_length: int):
        if trained_length > self.embed.shape[1] or trained_length <= 0:
            raise ValueError("Trained length is out of bounds")
        self.trained_length = trained_length

    def forward(self, x):
        _, seq_length, _ = x.shape  # seq_length is the target sequence length for x
        # print('seq_length', seq_length)

        # self.embed is [wh, frames, dim]; slice to the target length first
        embeddings = self.embed[:, :seq_length]

        # If a trained length is set and differs from the target length,
        # linearly interpolate the embedding along the frame axis
        if self.trained_length != -1 and seq_length != self.trained_length:
            embeddings = embeddings.permute(0, 2, 1)  # [wh, dim, frames]
            embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
            embeddings = embeddings.permute(0, 2, 1)  # back to [wh, frames, dim]

        # The interpolated embeddings must match the sequence length of x
        if embeddings.shape[1] != seq_length:
            raise ValueError(
                f"Interpolated embeddings sequence length {embeddings.shape[1]} "
                f"does not match x's sequence length {seq_length}"
            )

        if x.shape[0] != embeddings.shape[0]:
            # Tile the embedding over the batch dimension before adding it to x
            x = x + embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1) * self.scale
        else:
            # embeddings is [batch, seq_length, dim], matching x
            x = x + embeddings * self.scale

        return x

    def forward_average(self, x):
        _, seq_length, _ = x.shape  # seq_length is the target sequence length for x
        # print('seq_length', seq_length)

        # self.embed is [wh, frames, dim]; slice to the target length first
        embeddings = self.embed[:, :seq_length]

        # If a trained length is set and differs from the target length,
        # linearly interpolate the embedding along the frame axis
        if self.trained_length != -1 and seq_length != self.trained_length:
            embeddings = embeddings.permute(0, 2, 1)  # [wh, dim, frames]
            embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
            embeddings = embeddings.permute(0, 2, 1)  # back to [wh, frames, dim]

        # The interpolated embeddings must match the sequence length of x
        if embeddings.shape[1] != seq_length:
            raise ValueError(
                f"Interpolated embeddings sequence length {embeddings.shape[1]} "
                f"does not match x's sequence length {seq_length}"
            )

        # Subtract the per-position mean over frames so only the temporal variation is injected
        embeddings_mean = embeddings.mean(dim=1, keepdim=True)
        embeddings = embeddings - embeddings_mean

        if x.shape[0] != embeddings.shape[0]:
            x = x + embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1) * self.scale
        else:
            # embeddings is [batch, seq_length, dim], matching x
            x = x + embeddings * self.scale

        return x

    def forward_frameSubtraction(self, x):
        _, seq_length, _ = x.shape  # seq_length is the target sequence length for x
        # print('seq_length', seq_length)

        # self.embed is [wh, frames, dim]; slice to the target length first
        embeddings = self.embed[:, :seq_length]

        # If a trained length is set and differs from the target length,
        # linearly interpolate the embedding along the frame axis
        if self.trained_length != -1 and seq_length != self.trained_length:
            embeddings = embeddings.permute(0, 2, 1)  # [wh, dim, frames]
            embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
            embeddings = embeddings.permute(0, 2, 1)  # back to [wh, frames, dim]

        # The interpolated embeddings must match the sequence length of x
        if embeddings.shape[1] != seq_length:
            raise ValueError(
                f"Interpolated embeddings sequence length {embeddings.shape[1]} "
                f"does not match x's sequence length {seq_length}"
            )

        # Replace frames 1..N with frame-to-frame differences; frame 0 keeps its original value
        embeddings_subtraction = embeddings[:, 1:] - embeddings[:, :-1]
        embeddings = embeddings.clone().detach()
        embeddings[:, 1:] = embeddings_subtraction
        # first frame minus mean
        # embeddings[:,0:1] = embeddings[:,0:1] - embeddings.mean(dim=1, keepdim=True)

        if x.shape[0] != embeddings.shape[0]:
            x = x + embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1) * self.scale
        else:
            # embeddings is [batch, seq_length, dim], matching x
            x = x + embeddings * self.scale

        return x
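
# Example usage (a minimal sketch with made-up shapes, kept as a comment so it
# does not run on import):
#
#   emb = MotionEmbedding(embed_dim=320, max_seq_length=32, wh=1)
#   emb.set_lengths(16)              # number of frames the embedding was trained on
#   emb.set_scale(1.0)
#   x = torch.randn(2, 16, 320)      # [batch * wh, frames, dim]
#   out = emb(x)                     # adds the scaled motion embedding to x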

class MotionEmbeddingSpatial(nn.Module):

    def __init__(self, h: int = None, w: int = None, embed_dim: int = None, max_seq_length: int = 32):
        super().__init__()
        self.embed = nn.Parameter(torch.zeros(h * w, max_seq_length, embed_dim))
        self.scale = 1.0
        self.trained_length = -1

    def set_scale(self, scale: float):
        self.scale = scale

    def set_lengths(self, trained_length: int):
        if trained_length > self.embed.shape[1] or trained_length <= 0:
            raise ValueError("Trained length is out of bounds")
        self.trained_length = trained_length

    def forward(self, x):
        _, seq_length, _ = x.shape  # seq_length is the target sequence length for x

        # self.embed is [h*w, frames, dim]; slice to the target length first
        embeddings = self.embed[:, :seq_length]

        # If a trained length is set and differs from the target length,
        # linearly interpolate the embedding along the frame axis
        if self.trained_length != -1 and seq_length != self.trained_length:
            embeddings = embeddings.permute(0, 2, 1)  # [h*w, dim, frames]
            embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
            embeddings = embeddings.permute(0, 2, 1)  # back to [h*w, frames, dim]

        # The interpolated embeddings must match the sequence length of x
        if embeddings.shape[1] != seq_length:
            raise ValueError(
                f"Interpolated embeddings sequence length {embeddings.shape[1]} "
                f"does not match x's sequence length {seq_length}"
            )

        if x.shape[0] != embeddings.shape[0]:
            x = x + embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1) * self.scale
        else:
            # embeddings is [batch, seq_length, dim], matching x
            x = x + embeddings * self.scale

        return x
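
# Construction sketch for the spatial variant (illustrative numbers only): one
# embedding row per spatial position of a 16x16 latent.
#
#   spatial_emb = MotionEmbeddingSpatial(h=16, w=16, embed_dim=320, max_seq_length=32)
#   spatial_emb.embed.shape          # torch.Size([256, 32, 320]) == [h*w, frames, dim]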

def inject_motion_embeddings(model, combinations=None, config=None):
    spatial_shape = np.array([config.dataset.height, config.dataset.width])
    shape32 = np.ceil(spatial_shape / 32).astype(int)
    shape16 = np.ceil(spatial_shape / 16).astype(int)

    spatial_name = 'vSpatial'
    replacement_dict = {}

    # support for 32 frames
    max_seq_length = 32
    inject_layers = []

    for name, module in model.named_modules():
        # only temporal transformer blocks receive a positional motion embedding
        PETemporal = '.temp_attentions.' in name
        if not (PETemporal and re.search(r'transformer_blocks\.\d+$', name)):
            continue
        if not ([name.split('_')[0], module.norm1.normalized_shape[0]] in combinations):
            continue
        replacement_dict[f'{name}.pos_embed'] = MotionEmbedding(
            max_seq_length=max_seq_length,
            embed_dim=module.norm1.normalized_shape[0]
        ).to(dtype=model.dtype, device=model.device)

    replacement_keys = list(set(replacement_dict.keys()))
    temp_attn_list = [name.replace('pos_embed', 'attn1') for name in replacement_keys] + \
                     [name.replace('pos_embed', 'attn2') for name in replacement_keys]
    embed_dims = [replacement_dict[key].embed.shape[2] for key in replacement_keys]

    for temp_attn_index, temp_attn in enumerate(temp_attn_list):
        # the 'up'/'down' part of the UNet and the block index decide the spatial resolution
        place_in_net = temp_attn.split('_')[0]
        pattern = r'(\d+)\.temp_attentions'
        match = re.search(pattern, temp_attn)
        index_in_net = match.group(1)

        h, w = None, None
        if place_in_net == 'up':
            if index_in_net == "1":
                h, w = shape32
            elif index_in_net == "2":
                h, w = shape16
        elif place_in_net == 'down':
            if index_in_net == "1":
                h, w = shape16
            elif index_in_net == "2":
                h, w = shape32

        replacement_dict[temp_attn + '.' + spatial_name] = \
            MotionEmbedding(
                wh=h * w,
                embed_dim=embed_dims[temp_attn_index % len(replacement_keys)]
            ).to(dtype=model.dtype, device=model.device)

    for name, new_module in replacement_dict.items():
        parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
        module_name = name.rsplit('.', 1)[-1]

        parent_module = model
        if parent_name:
            parent_module = dict(model.named_modules())[parent_name]

        if [parent_name.split('_')[0], new_module.embed.shape[-1]] in combinations:
            inject_layers.append(name)

        setattr(parent_module, module_name, new_module)

    inject_layers = list(set(inject_layers))
    # for name in inject_layers:
    #     print(f"Injecting motion embedding at {name}")

    # Only the injected embedding parameters remain trainable
    parameters_list = []
    for name, para in model.named_parameters():
        if 'pos_embed' in name or spatial_name in name:
            parameters_list.append(para)
            para.requires_grad = True
        else:
            para.requires_grad = False

    return parameters_list, inject_layers
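
# Example call (a hedged sketch; `unet`, `config`, and the combinations below are
# placeholders for whatever the training script actually passes in):
#
#   combinations = [['up', 1280], ['down', 1280]]   # [place_in_net, embed_dim] pairs
#   params, injected = inject_motion_embeddings(unet, combinations=combinations, config=config)
#   optimizer = torch.optim.AdamW(params, lr=1e-4)  # only the injected embeddings stay trainable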

def save_motion_embeddings(model, file_path):
    # Collect the embedding tensor from every MotionEmbedding / MotionEmbeddingSpatial module
    motion_embeddings = {
        name: module.embed
        for name, module in model.named_modules()
        if isinstance(module, (MotionEmbedding, MotionEmbeddingSpatial))
    }
    # Save the motion embeddings to the specified file path
    torch.save(motion_embeddings, file_path)

def load_motion_embeddings(model, saved_embeddings):
    for key, embedding in saved_embeddings.items():
        # Extract parent module and module name from the key
        parent_name = key.rsplit('.', 1)[0] if '.' in key else ''
        module_name = key.rsplit('.', 1)[-1]

        # Retrieve the parent module
        parent_module = model
        if parent_name:
            parent_module = dict(model.named_modules())[parent_name]

        # Create a new MotionEmbedding instance with the saved dimensions
        new_module = MotionEmbedding(
            wh=embedding.shape[0],
            embed_dim=embedding.shape[-1],
            max_seq_length=embedding.shape[-2]
        )

        # Assign the loaded weights, moved to the model's dtype and device
        new_module.embed = nn.Parameter(embedding.to(dtype=model.dtype, device=model.device))

        # Replace the corresponding module in the model with the new instance
        setattr(parent_module, module_name, new_module)
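
# Save / load round trip (file path is illustrative):
#
#   save_motion_embeddings(unet, 'motion_embed.pt')
#   load_motion_embeddings(unet, torch.load('motion_embed.pt'))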

def set_motion_embedding_scale(model, scale_value):
    # Set the scale attribute on every MotionEmbedding in the model
    for _, module in model.named_modules():
        if isinstance(module, MotionEmbedding):
            module.scale = scale_value


def set_motion_embedding_length(model, trained_length):
    # Set the trained length on every MotionEmbedding in the model
    for _, module in model.named_modules():
        if isinstance(module, MotionEmbedding):
            module.trained_length = trained_length
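
# Inference-time sketch (values are illustrative): scale the learned motion and tell
# every embedding how many frames it was trained on so forward() can interpolate.
#
#   set_motion_embedding_scale(unet, 1.0)
#   set_motion_embedding_length(unet, 16)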