from typing import Optional, Tuple

import torch
import torch.distributed as dist
from einops import rearrange
from torch import nn

from pipeline import SelfForcingTrainingPipeline
from utils.loss import get_denoising_loss
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper

class BaseModel(nn.Module):
    def __init__(self, args, device):
        super().__init__()
        self._initialize_models(args, device)
        self.device = device
        self.args = args
        self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32
        if hasattr(args, "denoising_step_list"):
            self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
            if args.warp_denoising_step:
                # Map the integer step indices (0-1000) onto the scheduler's actual timestep
                # values (with a terminal 0 appended), so the few-step schedule follows any
                # timestep shift applied by the scheduler.
                timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
                self.denoising_step_list = timesteps[1000 - self.denoising_step_list]

    def _initialize_models(self, args, device):
        self.real_model_name = getattr(args, "real_name", "Wan2.1-T2V-1.3B")
        self.fake_model_name = getattr(args, "fake_name", "Wan2.1-T2V-1.3B")

        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, is_causal=False)
        self.real_score.model.requires_grad_(False)

        self.fake_score = WanDiffusionWrapper(model_name=self.fake_model_name, is_causal=False)
        self.fake_score.model.requires_grad_(True)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)

    def _get_timestep(
        self,
        min_timestep: int,
        max_timestep: int,
        batch_size: int,
        num_frame: int,
        num_frame_per_block: int,
        uniform_timestep: bool = False
    ) -> torch.Tensor:
        """
        Randomly sample a timestep tensor of shape [batch_size, num_frame], with values drawn
        uniformly from [min_timestep, max_timestep).
        - If uniform_timestep is True, all frames share a single sampled timestep.
        - Otherwise, every block of num_frame_per_block frames shares the timestep sampled for
          the first frame of that block (when independent_first_frame is set, the first frame
          keeps its own timestep and the blocking starts from the second frame).
        An illustrative example follows this method.
        """
        if uniform_timestep:
            timestep = torch.randint(
                min_timestep,
                max_timestep,
                [batch_size, 1],
                device=self.device,
                dtype=torch.long
            ).repeat(1, num_frame)
            return timestep
        else:
            timestep = torch.randint(
                min_timestep,
                max_timestep,
                [batch_size, num_frame],
                device=self.device,
                dtype=torch.long
            )
            # Make the noise level the same within every block
            if self.independent_first_frame:
                # The first frame always keeps its own timestep
                timestep_from_second = timestep[:, 1:]
                timestep_from_second = timestep_from_second.reshape(
                    timestep_from_second.shape[0], -1, num_frame_per_block)
                timestep_from_second[:, :, 1:] = timestep_from_second[:, :, 0:1]
                timestep_from_second = timestep_from_second.reshape(
                    timestep_from_second.shape[0], -1)
                timestep = torch.cat([timestep[:, 0:1], timestep_from_second], dim=1)
            else:
                timestep = timestep.reshape(
                    timestep.shape[0], -1, num_frame_per_block)
                timestep[:, :, 1:] = timestep[:, :, 0:1]
                timestep = timestep.reshape(timestep.shape[0], -1)
            return timestep
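
    # Illustrative example (not executed): with batch_size=1, num_frame=7,
    # num_frame_per_block=3, uniform_timestep=False and independent_first_frame=True,
    # a single draw of _get_timestep might look like
    #     [[372, 851, 851, 851, 129, 129, 129]]
    # i.e. the first frame keeps its own timestep and each subsequent block of three
    # frames shares the timestep sampled for its first frame.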

class SelfForcingModel(BaseModel):
    def __init__(self, args, device):
        super().__init__(args, device)
        self.denoising_loss_func = get_denoising_loss(args.denoising_loss_type)()
        # The inference pipeline is built lazily on the first backward-simulation call.
        self.inference_pipeline = None

    def _run_generator(
        self,
        image_or_video_shape,
        conditional_dict: dict,
        initial_latent: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], int, int]:
        """
        Optionally simulate the generator's input from noise using backward simulation
        and then run the generator for one step.
        Input:
            - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
            - initial_latent: an optional tensor of initial latents [B, F, C, H, W] (e.g. the first-frame latent for i2v).
        Output:
            - pred_image_or_video: a tensor of generated latents with shape [B, 21, C, H, W] (the last 21 frames).
            - gradient_mask: a boolean tensor marking the frames that should receive gradients, or None.
            - denoised_timestep_from, denoised_timestep_to: integers describing the timesteps the
              returned prediction was denoised from and to.
        """
        # Step 1: Sample noise and backward simulate the generator's input
        assert getattr(self.args, "backward_simulation", True), "Backward simulation needs to be enabled"
        if initial_latent is not None:
            conditional_dict["initial_latent"] = initial_latent
        if self.args.i2v:
            noise_shape = [image_or_video_shape[0], image_or_video_shape[1] - 1, *image_or_video_shape[2:]]
        else:
            noise_shape = image_or_video_shape.copy()

        # During training, the number of generated frames is uniformly sampled from
        # [21, self.num_training_frames] while remaining a multiple of self.num_frame_per_block.
        min_num_frames = 20 if self.args.independent_first_frame else 21
        max_num_frames = self.num_training_frames - 1 if self.args.independent_first_frame else self.num_training_frames
        assert max_num_frames % self.num_frame_per_block == 0
        assert min_num_frames % self.num_frame_per_block == 0
        max_num_blocks = max_num_frames // self.num_frame_per_block
        min_num_blocks = min_num_frames // self.num_frame_per_block
        # Sample the number of blocks and broadcast it from rank 0 so that every process
        # generates the same number of frames.
        num_generated_blocks = torch.randint(min_num_blocks, max_num_blocks + 1, (1,), device=self.device)
        dist.broadcast(num_generated_blocks, src=0)
        num_generated_blocks = num_generated_blocks.item()
        num_generated_frames = num_generated_blocks * self.num_frame_per_block
        if self.args.independent_first_frame and initial_latent is None:
            num_generated_frames += 1
            min_num_frames += 1
        noise_shape[1] = num_generated_frames

        pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(
            noise=torch.randn(noise_shape,
                              device=self.device, dtype=self.dtype),
            **conditional_dict,
        )

        # Keep only the last 21 latent frames. To make them a valid i2v-style window, decode
        # the earlier latents to pixels, take the last pixel frame, and re-encode it as the
        # image latent that starts the window.
        if pred_image_or_video.shape[1] > 21:
            with torch.no_grad():
                latent_to_decode = pred_image_or_video[:, :-20, ...]
                # Decode to video
                pixels = self.vae.decode_to_pixel(latent_to_decode)
                frame = pixels[:, -1:, ...].to(self.dtype)
                frame = rearrange(frame, "b t c h w -> b c t h w")
                # Encode the frame to get its image latent
                image_latent = self.vae.encode_to_latent(frame).to(self.dtype)
            pred_image_or_video_last_21 = torch.cat([image_latent, pred_image_or_video[:, -20:, ...]], dim=1)
        else:
            pred_image_or_video_last_21 = pred_image_or_video

        if num_generated_frames != min_num_frames:
            # Currently, we do not use gradients for the first chunk, since it contains image latents
            gradient_mask = torch.ones_like(pred_image_or_video_last_21, dtype=torch.bool)
            if self.args.independent_first_frame:
                gradient_mask[:, :1] = False
            else:
                gradient_mask[:, :self.num_frame_per_block] = False
        else:
            gradient_mask = None

        pred_image_or_video_last_21 = pred_image_or_video_last_21.to(self.dtype)
        return pred_image_or_video_last_21, gradient_mask, denoised_timestep_from, denoised_timestep_to
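
    # Illustrative usage sketch (hypothetical shapes and helper calls, not taken from the
    # training code): a subclass computing its loss might do something like
    #
    #     conditional_dict = self.text_encoder(text_prompts=prompts)
    #     latents, grad_mask, t_from, t_to = self._run_generator(
    #         image_or_video_shape=[batch_size, 21, 16, 60, 104],
    #         conditional_dict=conditional_dict,
    #     )
    #     # ...compute the distillation / denoising losses on `latents`, ignoring
    #     # positions where `grad_mask` is False.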

    def _consistency_backward_simulation(
        self,
        noise: torch.Tensor,
        **conditional_dict: dict
    ) -> Tuple[torch.Tensor, int, int]:
        """
        Simulate the generator's input from noise to avoid training/inference mismatch.
        See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
        Here we use the consistency sampler (https://arxiv.org/abs/2303.01469).
        Input:
            - noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W], where the number of frames is 1 for images.
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
        Output:
            - the simulated clean latents with shape [B, F, C, H, W], together with the timesteps the
              final prediction was denoised from and to. A conceptual sketch of the rollout follows this method.
        """
        if self.inference_pipeline is None:
            self._initialize_inference_pipeline()

        return self.inference_pipeline.inference_with_trajectory(
            noise=noise, **conditional_dict
        )
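
    # Conceptual sketch of the backward simulation (illustrative pseudocode only; the real
    # rollout lives in SelfForcingTrainingPipeline.inference_with_trajectory and is causal
    # and block-wise):
    #
    #     x = noise
    #     for t, t_next in zip(steps, steps[1:]):            # steps = denoising_step_list
    #         x0_pred = generator(x, t, **conditional_dict)  # one-step x0 prediction
    #         x = add_noise(x0_pred, fresh_noise, t_next)    # re-noise to the next level
    #     return x0_pred (plus the timesteps it was denoised from / to)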

    def _initialize_inference_pipeline(self):
        """
        Lazily initialize the inference pipeline on the first backward-simulation call.
        The inference code is encapsulated in a model-dependent pipeline class, and we pass
        our FSDP-wrapped modules into the pipeline to save memory.
        """
        self.inference_pipeline = SelfForcingTrainingPipeline(
            denoising_step_list=self.denoising_step_list,
            scheduler=self.scheduler,
            generator=self.generator,
            num_frame_per_block=self.num_frame_per_block,
            independent_first_frame=self.args.independent_first_frame,
            same_step_across_blocks=self.args.same_step_across_blocks,
            last_step_only=self.args.last_step_only,
            num_max_frames=self.num_training_frames,
            context_noise=self.args.context_noise
        )
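

# Subclassing sketch (illustrative, with assumed argument names): the classes above read
# several attributes that a concrete model is expected to define, e.g. `num_training_frames`,
# `num_frame_per_block` and `independent_first_frame`. A hypothetical subclass could set
# them from `args` before calling `_run_generator`:
#
#     class MySelfForcingModel(SelfForcingModel):
#         def __init__(self, args, device):
#             super().__init__(args, device)
#             self.num_training_frames = args.num_training_frames
#             self.num_frame_per_block = args.num_frame_per_block
#             self.independent_first_frame = args.independent_first_frame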