from typing import Tuple

import torch
import torch.nn.functional as F

from model.base import BaseModel
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class ODERegression(BaseModel):
    def __init__(self, args, device):
        """
        Initialize the ODERegression module.
        This class is self-contained and computes generator losses
        in the forward pass, given precomputed ODE solution pairs.
        It supports the ODE regression loss for both causal and bidirectional models.
        See Sec. 4.3 of CausVid (https://arxiv.org/abs/2412.07772) for details.
        """
        super().__init__(args, device)
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)
        if getattr(args, "generator_ckpt", False):
            print(f"Loading pretrained generator from {args.generator_ckpt}")
            state_dict = torch.load(args.generator_ckpt, map_location="cpu")["generator"]
            self.generator.load_state_dict(state_dict, strict=True)

        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True

        if args.gradient_checkpointing:
            self.generator.enable_gradient_checkpointing()

        # Step 2: Initialize all hyperparameters
        self.timestep_shift = getattr(args, "timestep_shift", 1.0)
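
    # A minimal construction sketch (hypothetical `args`; the real object is
    # the training config namespace, and whatever BaseModel.__init__
    # additionally reads, e.g. the denoising step list, must also be present;
    # values below are purely illustrative):
    #
    #   from types import SimpleNamespace
    #   args = SimpleNamespace(
    #       model_kwargs={}, generator_ckpt=None, num_frame_per_block=3,
    #       independent_first_frame=False, gradient_checkpointing=False,
    #       timestep_shift=1.0, i2v=False,
    #   )
    #   model = ODERegression(args, device="cuda")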

    def _initialize_models(self, args):
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)
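
        # Note: only the generator receives gradients; the text encoder and
        # VAE are frozen and serve purely as conditioning / decoding modules.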

    def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a tensor containing whole ODE sampling trajectories,
        randomly choose an intermediate timestep for each frame and return the
        selected latents together with the corresponding timesteps.
        Input:
        - ode_latent: a tensor containing the whole ODE sampling trajectories [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
        Output:
        - noisy_input: a tensor containing the selected latents [batch_size, num_frames, num_channels, height, width].
        - timestep: a tensor containing the corresponding timesteps [batch_size, num_frames].
        """
        batch_size, num_denoising_steps, num_frames, num_channels, height, width = ode_latent.shape

        # Step 1: Randomly choose a timestep for each frame
        index = self._get_timestep(
            0,
            len(self.denoising_step_list),
            batch_size,
            num_frames,
            self.num_frame_per_block,
            uniform_timestep=False
        )
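        # `uniform_timestep=False` presumably makes `_get_timestep` (defined
        # in BaseModel, not shown here) draw an independent index per frame
        # block, so one sample supervises the generator at several noise levels.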

        if self.args.i2v:
            index[:, 0] = len(self.denoising_step_list) - 1
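            # Trajectories run from most noisy to clean (see generator_loss),
            # so the last index pins the first frame to its cleanest latent:
            # in image-to-video mode it acts as the clean conditioning frame.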

        noisy_input = torch.gather(
            ode_latent, dim=1,
            index=index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
                -1, -1, -1, num_channels, height, width).to(self.device)
        ).squeeze(1)
        timestep = self.denoising_step_list[index].to(self.device)
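        # Shape check with illustrative values: for ode_latent of shape
        # [B, S, F, C, H, W] = [2, 4, 21, 16, 60, 104], `index` is [B, F],
        # `noisy_input` comes out as [2, 21, 16, 60, 104] and `timestep` as
        # [2, 21] -- one timestep per frame, as the docstring states.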

        # if self.extra_noise_step > 0:
        #     random_timestep = torch.randint(0, self.extra_noise_step, [
        #         batch_size, num_frames], device=self.device, dtype=torch.long)
        #     perturbed_noisy_input = self.scheduler.add_noise(
        #         noisy_input.flatten(0, 1),
        #         torch.randn_like(noisy_input.flatten(0, 1)),
        #         random_timestep.flatten(0, 1)
        #     ).detach().unflatten(0, (batch_size, num_frames)).type_as(noisy_input)
        #     noisy_input[timestep == 0] = perturbed_noisy_input[timestep == 0]

        return noisy_input, timestep

    def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
        """
        Generate images/videos from noisy latents and compute the ODE regression loss.
        Input:
        - ode_latent: a tensor containing the ODE latents [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
          They are ordered from most noisy to clean latents.
        - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
        Output:
        - loss: a scalar tensor representing the generator loss.
        - log_dict: a dictionary containing additional information, e.g. the per-timestep loss breakdown.
        """
        # Step 1: Run generator on noisy latents
        target_latent = ode_latent[:, -1]
        noisy_input, timestep = self._prepare_generator_input(ode_latent=ode_latent)

        _, pred_image_or_video = self.generator(
            noisy_image_or_video=noisy_input,
            conditional_dict=conditional_dict,
            timestep=timestep
        )

        # Step 2: Compute the regression loss
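        # If the step list ends at 0, inputs drawn at timestep 0 are already
        # the clean ODE endpoint itself, so regressing them onto the target is
        # degenerate; the mask below drops those frames from the loss.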
        mask = timestep != 0
        loss = F.mse_loss(
            pred_image_or_video[mask], target_latent[mask], reduction="mean")

        log_dict = {
            "unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction="none").mean(dim=[1, 2, 3, 4]).detach(),
            "timestep": timestep.float().mean(dim=1).detach(),
            "input": noisy_input.detach(),
            "output": pred_image_or_video.detach(),
        }
        return loss, log_dict
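

if __name__ == "__main__":
    # A minimal standalone sketch (assumes nothing beyond torch itself) that
    # verifies the per-frame gather semantics used in
    # `_prepare_generator_input`, on toy shapes with no model weights.
    # `F_` avoids shadowing torch.nn.functional, imported above as `F`.
    B, S, F_, C, H, W = 2, 4, 3, 1, 2, 2
    trajectory = torch.arange(B * S * F_ * C * H * W, dtype=torch.float32)
    trajectory = trajectory.reshape(B, S, F_, C, H, W)
    index = torch.randint(0, S, (B, F_))

    picked = torch.gather(
        trajectory, dim=1,
        index=index.reshape(B, 1, F_, 1, 1, 1).expand(-1, -1, -1, C, H, W)
    ).squeeze(1)
    assert picked.shape == (B, F_, C, H, W)

    # Frame f of batch b should equal trajectory step index[b, f] of frame f.
    for b in range(B):
        for f in range(F_):
            assert torch.equal(picked[b, f], trajectory[b, index[b, f], f])
    print("gather semantics verified:", tuple(picked.shape))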