from typing import List

import torch

from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class BidirectionalInferencePipeline(torch.nn.Module):
    def __init__(
        self,
        args,
        device,
        generator=None,
        text_encoder=None,
        vae=None
    ):
        super().__init__()
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
        self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
        self.vae = WanVAEWrapper() if vae is None else vae
        # Step 2: Initialize all bidirectional wan hyperparameters
        self.scheduler = self.generator.get_scheduler()
        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long, device=device)
        if self.denoising_step_list[-1] == 0:
            # Drop the trailing zero timestep: it denotes the clean sample and
            # is never used as a denoising input at inference time.
            self.denoising_step_list = self.denoising_step_list[:-1]

        if args.warp_denoising_step:
            # Map nominal integer step indices onto the scheduler's (possibly
            # shifted) timestep values, appending 0 as the final clean step.
            timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
            self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
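        # Worked example of the warp above (hypothetical numbers, assuming a
        # 1000-step scheduler whose `timesteps` run from most to least noisy):
        #   denoising_step_list        = [999, 748, 502, 247]
        #   timesteps                  = cat(scheduler.timesteps, [0])  # 1001 entries
        #   1000 - denoising_step_list = [1, 252, 498, 753]
        #   warped list                = timesteps[[1, 252, 498, 753]]
        # i.e. each nominal step index is replaced by the scheduler's actual
        # timestep value at that position.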

    def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
        """
        Perform inference on the given noise and text prompts.

        Inputs:
            noise (torch.Tensor): The input noise tensor of shape
                (batch_size, num_frames, num_channels, height, width).
            text_prompts (List[str]): The list of text prompts.

        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_frames, num_channels, height, width).
                It is normalized to be in the range [0, 1].
        """
        conditional_dict = self.text_encoder(
            text_prompts=text_prompts
        )

        # Start from pure noise at the highest timestep.
        noisy_image_or_video = noise

        # Few-step sampling: at each timestep the generator predicts the clean
        # video directly; between steps that prediction is re-noised to the
        # next (lower) timestep to form the generator's next input. The final
        # timestep needs no re-noising, so it is denoised and decoded directly.
        for index, current_timestep in enumerate(self.denoising_step_list):
            _, pred_image_or_video = self.generator(
                noisy_image_or_video=noisy_image_or_video,
                conditional_dict=conditional_dict,
                timestep=torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
            )  # [B, F, C, H, W]

            if index < len(self.denoising_step_list) - 1:
                next_timestep = self.denoising_step_list[index + 1] * torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device)
                # add_noise expects a flat batch, so fold the frame axis into
                # the batch axis and restore it afterwards.
                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1),
                    torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    next_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

        # Decode latents to pixels and map from [-1, 1] to [0, 1].
        video = self.vae.decode_to_pixel(pred_image_or_video)
        video = (video * 0.5 + 0.5).clamp(0, 1)

        return video
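

# Minimal usage sketch (an illustration, not part of the original module).
# `args` mirrors the fields read in __init__ above; the denoising schedule and
# the latent shape [batch, frames, channels, height, width] are hypothetical
# examples, since the expected dimensions depend on the Wan model configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        model_kwargs={},
        denoising_step_list=[999, 748, 502, 247, 0],  # hypothetical 4-step schedule
        warp_denoising_step=False,
    )
    device = torch.device("cuda")
    pipeline = BidirectionalInferencePipeline(args, device=device).to(device)

    noise = torch.randn(1, 21, 16, 60, 104, device=device)  # hypothetical latent shape
    video = pipeline.inference(noise, ["a corgi running on the beach"])
    print(video.shape)  # [1, F, C, H, W], values in [0, 1]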