Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,962 Bytes
0fd2f06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import torch.nn.functional as F
from typing import Tuple
import torch
from model.base import BaseModel
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
class ODERegression(BaseModel):
    """Generator trained with an ODE regression loss on precomputed trajectories.

    Each training sample is a full ODE sampling trajectory. A (per-frame)
    intermediate state of that trajectory is fed to the generator, and the
    prediction is regressed onto the final (clean) state of the trajectory.
    """

    def __init__(self, args, device):
        """
        Initialize the ODERegression module.
        This class is self-contained and computes generator losses
        in the forward pass given precomputed ODE solution pairs.
        See Sec 4.3 of CausVid https://arxiv.org/abs/2412.07772 for details.

        Input:
            - args: configuration namespace. Optional attributes read here:
              model_kwargs, generator_ckpt, num_frame_per_block,
              independent_first_frame, gradient_checkpointing, timestep_shift.
            - device: forwarded to BaseModel.__init__.
        """
        super().__init__(args, device)

        # Step 1: Initialize all models
        # NOTE(review): _initialize_models() below builds the same generator; if
        # BaseModel.__init__ invokes that hook, the wrapper is constructed twice.
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        if getattr(args, "generator_ckpt", False):
            print(f"Loading pretrained generator from {args.generator_ckpt}")
            # SECURITY: torch.load unpickles the file. Only load checkpoints from
            # trusted sources (or switch to weights_only=True if the checkpoint
            # format allows it).
            state_dict = torch.load(args.generator_ckpt, map_location="cpu")[
                'generator']
            self.generator.load_state_dict(
                state_dict, strict=True
            )

        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True

        # Read with a default like every other optional setting above, so a
        # config that omits the key does not raise AttributeError.
        if getattr(args, "gradient_checkpointing", False):
            self.generator.enable_gradient_checkpointing()

        # Step 2: Initialize all hyperparameters
        self.timestep_shift = getattr(args, "timestep_shift", 1.0)

    def _initialize_models(self, args):
        """
        Build the trainable generator and the frozen text encoder and VAE.

        NOTE(review): confirm this signature matches how BaseModel calls the hook.
        """
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

    @torch.no_grad()
    def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a tensor containing the whole ODE sampling trajectories,
        randomly choose an intermediate step per frame and return the latent
        as well as the corresponding timestep.
        Input:
            - ode_latent: a tensor containing the whole ODE sampling trajectories [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
        Output:
            - noisy_input: a tensor containing the selected latent [batch_size, num_frames, num_channels, height, width].
            - timestep: a tensor containing the corresponding timestep of every frame [batch_size, num_frames].
        """
        batch_size, _, num_frames, num_channels, height, width = ode_latent.shape

        # Step 1: Randomly choose a timestep index for each frame
        index = self._get_timestep(
            0,
            len(self.denoising_step_list),
            batch_size,
            num_frames,
            self.num_frame_per_block,
            uniform_timestep=False
        )
        if self.args.i2v:
            # In image-to-video mode the first frame always uses the last
            # entry of the denoising step list.
            index[:, 0] = len(self.denoising_step_list) - 1

        # Step 2: Pick, for every frame, the trajectory state at its chosen index.
        # The index is broadcast over (channels, height, width) so gather selects
        # along the denoising-step axis only.
        gather_index = index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
            -1, -1, -1, num_channels, height, width).to(self.device)
        noisy_input = torch.gather(ode_latent, dim=1, index=gather_index).squeeze(1)
        timestep = self.denoising_step_list[index].to(self.device)

        # NOTE: an optional "extra noise" perturbation of the timestep == 0
        # inputs (driven by self.extra_noise_step / self.scheduler.add_noise)
        # was previously sketched here as commented-out code; it is disabled.

        return noisy_input, timestep

    def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
        """
        Generate image/videos from noisy latents and compute the ODE regression loss.
        Input:
            - ode_latent: a tensor containing the ODE latents [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
              They are ordered from most noisy to clean latents.
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
        Output:
            - loss: a scalar tensor representing the generator loss. Frames whose
              timestep is 0 are excluded; if every frame is excluded the loss is
              an exact zero that is still attached to the generator graph.
            - log_dict: a dictionary containing additional information for loss timestep breakdown.
        """
        # Step 1: Run generator on noisy latents
        target_latent = ode_latent[:, -1]

        noisy_input, timestep = self._prepare_generator_input(
            ode_latent=ode_latent)

        _, pred_image_or_video = self.generator(
            noisy_image_or_video=noisy_input,
            conditional_dict=conditional_dict,
            timestep=timestep
        )

        # Step 2: Compute the regression loss on frames with a non-zero timestep
        mask = timestep != 0
        if mask.any():
            loss = F.mse_loss(
                pred_image_or_video[mask], target_latent[mask], reduction="mean")
        else:
            # mse_loss with reduction="mean" over an empty selection is NaN,
            # which would corrupt the optimizer step. Return a zero that keeps
            # the autograd graph (and thus gradient synchronisation) intact.
            loss = pred_image_or_video.sum() * 0.0

        log_dict = {
            "unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction='none').mean(dim=[1, 2, 3, 4]).detach(),
            "timestep": timestep.float().mean(dim=1).detach(),
            "input": noisy_input.detach(),
            "output": pred_image_or_video.detach(),
        }

        return loss, log_dict
|