# Source: self-forcing / model / ode_regression.py
# Uploaded by: multimodalart
# Commit message: "Upload 80 files"
# Commit: 0fd2f06 (verified)
import torch.nn.functional as F
from typing import Tuple
import torch
from model.base import BaseModel
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
class ODERegression(BaseModel):
    """Generator trained by regressing precomputed ODE trajectories (CausVid Sec. 4.3)."""

    def __init__(self, args, device):
        """
        Initialize the ODERegression module.

        This class is self-contained and computes generator losses in the
        forward pass given precomputed ODE solution pairs. The regression loss
        itself does not depend on whether the generator is causal; this
        constructor builds the generator with ``is_causal=True``.
        See Sec 4.3 of CausVid https://arxiv.org/abs/2412.07772 for details.

        Args:
            args: configuration namespace. Optional attributes read here:
                ``model_kwargs``, ``generator_ckpt``, ``num_frame_per_block``,
                ``independent_first_frame``, ``gradient_checkpointing`` and
                ``timestep_shift``.
            device: forwarded unchanged to ``BaseModel.__init__``.
        """
        super().__init__(args, device)

        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        generator_ckpt = getattr(args, "generator_ckpt", False)
        if generator_ckpt:
            print(f"Loading pretrained generator from {generator_ckpt}")
            # The checkpoint stores the generator weights under the 'generator' key.
            state_dict = torch.load(generator_ckpt, map_location="cpu")["generator"]
            self.generator.load_state_dict(state_dict, strict=True)

        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True

        # Read like the other optional settings so configs without the flag still work.
        if getattr(args, "gradient_checkpointing", False):
            self.generator.enable_gradient_checkpointing()

        # Step 2: Initialize all hyperparameters
        self.timestep_shift = getattr(args, "timestep_shift", 1.0)

    def _initialize_models(self, args, device=None):
        """
        Build the generator, text encoder and VAE.

        Only the generator's model is trainable; the text encoder and VAE are frozen.

        Args:
            args: configuration namespace; ``model_kwargs`` (if present) is
                forwarded to ``WanDiffusionWrapper``.
            device: unused here. NOTE(review): accepted (optional) so the hook
                also works if it is invoked as ``_initialize_models(args, device)``;
                confirm which signature ``BaseModel`` calls.
        """
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

    @torch.no_grad()
    def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Pick a random intermediate latent per frame from the ODE trajectories.

        Input:
            - ode_latent: the whole ODE sampling trajectories
              [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
        Output:
            - noisy_input: the selected latents [batch_size, num_frames, num_channels, height, width].
            - timestep: the timestep value matching each selected latent
              [batch_size, num_frames], on ``self.device``.
        """
        batch_size, _, num_frames, num_channels, height, width = ode_latent.shape
        num_steps = len(self.denoising_step_list)

        # Step 1: Randomly choose a denoising-step index for each frame; the
        # helper also receives num_frame_per_block to handle frame blocks.
        index = self._get_timestep(
            0,
            num_steps,
            batch_size,
            num_frames,
            self.num_frame_per_block,
            uniform_timestep=False,
        )
        if getattr(self.args, "i2v", False):
            # Image-to-video: pin the first frame to the last entry of the step list.
            index[:, 0] = num_steps - 1

        # Step 2: Gather, for every frame, the latent at its chosen trajectory
        # position. torch.gather needs the index on the same device as its input.
        gather_index = index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
            -1, -1, -1, num_channels, height, width).to(ode_latent.device)
        noisy_input = torch.gather(ode_latent, dim=1, index=gather_index).squeeze(1)

        # Step 3: Map indices to timestep values. Align devices first, since a
        # CPU tensor cannot be indexed with a CUDA index tensor.
        step_list = self.denoising_step_list.to(index.device)
        timestep = step_list[index].to(self.device)

        # NOTE: an optional "extra noise" perturbation of timestep==0 frames was
        # previously sketched here (disabled); see version history if needed.
        return noisy_input, timestep

    def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
        """
        Generate image/videos from noisy latents and compute the ODE regression loss.

        Input:
            - ode_latent: the ODE latents
              [batch_size, num_denoising_steps, num_frames, num_channels, height, width],
              ordered from most noisy to clean latents.
            - conditional_dict: conditional information (e.g. text embeddings,
              image embeddings) passed straight to the generator.
        Output:
            - loss: scalar MSE between the prediction and the clean latent
              (``ode_latent[:, -1]``), over frames whose timestep is non-zero.
              If every frame has timestep 0, a zero loss still attached to the
              graph is returned instead of NaN.
            - log_dict: detached diagnostics: per-sample unmasked MSE
              ("unnormalized_loss"), per-sample mean timestep ("timestep"),
              and the generator "input"/"output".
        """
        # Step 1: Run generator on noisy latents
        target_latent = ode_latent[:, -1]
        noisy_input, timestep = self._prepare_generator_input(ode_latent=ode_latent)
        _, pred_image_or_video = self.generator(
            noisy_image_or_video=noisy_input,
            conditional_dict=conditional_dict,
            timestep=timestep,
        )

        # Step 2: Compute the regression loss on frames with a non-zero timestep.
        mask = timestep != 0
        if mask.any():
            loss = F.mse_loss(
                pred_image_or_video[mask], target_latent[mask], reduction="mean")
        else:
            # mse_loss over zero elements is NaN; keep the graph connected so
            # backward() (and any gradient synchronization) still runs.
            loss = pred_image_or_video.sum() * 0.0

        log_dict = {
            "unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction='none').mean(dim=[1, 2, 3, 4]).detach(),
            "timestep": timestep.float().mean(dim=1).detach(),
            "input": noisy_input.detach(),
            "output": pred_image_or_video.detach(),
        }
        return loss, log_dict