from typing import List

import torch
from tqdm import tqdm

from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                  get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
    def __init__(self, args, device, generator=None, text_encoder=None, vae=None):
        super().__init__()
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
        self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
        self.vae = WanVAEWrapper() if vae is None else vae

        # Step 2: Initialize scheduler settings
        self.num_train_timesteps = args.num_train_timestep
        self.sampling_steps = 50
        self.sample_solver = 'unipc'
        self.shift = 8.0

        self.args = args

    def inference(
        self,
        noise: torch.Tensor,
        text_prompts: List[str],
        return_latents=False
    ) -> torch.Tensor:
        """
        Perform inference on the given noise and text prompts.

        Inputs:
            noise (torch.Tensor): The input noise tensor of shape
                (batch_size, num_frames, num_channels, height, width).
            text_prompts (List[str]): The list of text prompts.

        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_frames, num_channels, height, width).
                It is normalized to be in the range [0, 1].
        """
        conditional_dict = self.text_encoder(text_prompts=text_prompts)
        unconditional_dict = self.text_encoder(
            text_prompts=[self.args.negative_prompt] * len(text_prompts))

        latents = noise
        sample_scheduler = self._initialize_sample_scheduler(noise)

        for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
            latent_model_input = latents
            # Broadcast the scalar timestep across the batch and all
            # 21 latent frames (the default Wan latent sequence length).
            timestep = t * torch.ones(
                [latents.shape[0], 21], device=noise.device, dtype=torch.float32)

            # Classifier-free guidance: blend conditional and
            # unconditional flow predictions.
            flow_pred_cond, _ = self.generator(
                latent_model_input, conditional_dict, timestep)
            flow_pred_uncond, _ = self.generator(
                latent_model_input, unconditional_dict, timestep)
            flow_pred = flow_pred_uncond + self.args.guidance_scale * (
                flow_pred_cond - flow_pred_uncond)

            # One solver step; the scheduler expects a leading batch dimension.
            temp_x0 = sample_scheduler.step(
                flow_pred.unsqueeze(0), t, latents.unsqueeze(0),
                return_dict=False)[0]
            latents = temp_x0.squeeze(0)

        x0 = latents
        video = self.vae.decode_to_pixel(x0)
        # Map VAE output from [-1, 1] to [0, 1].
        video = (video * 0.5 + 0.5).clamp(0, 1)

        del sample_scheduler

        if return_latents:
            return video, latents
        else:
            return video

    def _initialize_sample_scheduler(self, noise):
        if self.sample_solver == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sample_scheduler.set_timesteps(
                self.sampling_steps, device=noise.device, shift=self.shift)
            self.timesteps = sample_scheduler.timesteps
        elif self.sample_solver == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
            self.timesteps, _ = retrieve_timesteps(
                sample_scheduler, device=noise.device, sigmas=sampling_sigmas)
        else:
            raise NotImplementedError("Unsupported solver.")
        return sample_scheduler
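

# --- Usage sketch ---
# A minimal example of driving the pipeline end to end. The fields on `args`
# (num_train_timestep, negative_prompt, guidance_scale, model_kwargs) are the
# ones read by the class above; their values here, and the latent noise shape
# (1 batch, 21 latent frames, 16 channels, 60x104 grid), are illustrative
# assumptions matching the hard-coded 21-frame timestep broadcast, not a
# guaranteed Wan configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = SimpleNamespace(
        num_train_timestep=1000,  # assumed training horizon
        negative_prompt="",       # assumed empty negative prompt
        guidance_scale=5.0,       # assumed CFG scale
        model_kwargs={},          # forwarded to WanDiffusionWrapper
    )

    pipeline = BidirectionalDiffusionInferencePipeline(args, device=device).to(device)

    # Latent-space noise; shape is an assumption for illustration.
    noise = torch.randn(1, 21, 16, 60, 104, device=device)
    video = pipeline.inference(noise, text_prompts=["a cat playing piano"])
    print(video.shape)  # (batch, frames, channels, height, width), values in [0, 1]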