from typing import List

import torch

from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class BidirectionalInferencePipeline(torch.nn.Module):
    def __init__(
        self,
        args,
        device,
        generator=None,
        text_encoder=None,
        vae=None
    ):
        super().__init__()
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
        self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
        self.vae = WanVAEWrapper() if vae is None else vae

        # Step 2: Initialize all bidirectional Wan hyperparameters
        self.scheduler = self.generator.get_scheduler()
        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long, device=device)
        if self.denoising_step_list[-1] == 0:
            # remove the zero timestep for inference
            self.denoising_step_list = self.denoising_step_list[:-1]

        if args.warp_denoising_step:
            # Map the nominal step indices onto the scheduler's (shifted) timestep grid.
            # Index on CPU to match the scheduler's timesteps, then move back to `device`.
            timesteps = torch.cat(
                (self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
            self.denoising_step_list = timesteps[
                1000 - self.denoising_step_list.cpu()].to(device)

    def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
        """
        Perform inference on the given noise and text prompts.

        Inputs:
            noise (torch.Tensor): The input noise tensor of shape
                (batch_size, num_frames, num_channels, height, width).
            text_prompts (List[str]): The list of text prompts.

        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_frames, num_channels, height, width).
                It is normalized to be in the range [0, 1].
        """
        conditional_dict = self.text_encoder(
            text_prompts=text_prompts
        )

        # initial point
        noisy_image_or_video = noise

        # Few-step sampling: at every timestep the generator predicts the clean
        # latent, which is re-noised to the next (smaller) timestep; the
        # prediction at the final timestep is decoded directly.
        for index, current_timestep in enumerate(self.denoising_step_list):
            _, pred_image_or_video = self.generator(
                noisy_image_or_video=noisy_image_or_video,
                conditional_dict=conditional_dict,
                timestep=torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
            )  # [B, F, C, H, W]

            if index < len(self.denoising_step_list) - 1:
                next_timestep = self.denoising_step_list[index + 1] * torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device)

                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1),
                    torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    next_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

        # Decode the final latent prediction to pixels and map [-1, 1] -> [0, 1].
        video = self.vae.decode_to_pixel(pred_image_or_video)
        video = (video * 0.5 + 0.5).clamp(0, 1)

        return video
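

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline). It assumes the Wan
# checkpoints needed by WanDiffusionWrapper / WanTextEncoder / WanVAEWrapper are
# available locally; the denoising_step_list, latent shape, and prompt below are
# illustrative placeholders, not the project's defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical config: a 4-step schedule ending at 0, with timestep warping disabled.
    args = SimpleNamespace(
        model_kwargs={},
        denoising_step_list=[999, 748, 502, 247, 0],
        warp_denoising_step=False,
    )

    pipeline = BidirectionalInferencePipeline(args, device=device).to(device)

    # Illustrative latent shape: (batch, frames, channels, height, width).
    noise = torch.randn(1, 21, 16, 60, 104, device=device)
    video = pipeline.inference(noise, text_prompts=["a corgi running on the beach"])
    print(video.shape)  # (batch, frames, channels, height, width), values in [0, 1]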