from typing import List

import torch
from tqdm import tqdm

from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                  get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
    def __init__(
        self,
        args,
        device,
        generator=None,
        text_encoder=None,
        vae=None
    ):
        super().__init__()

        # Step 1: Initialize all models. The generator runs in bidirectional
        # (non-causal) mode; pre-built modules can be injected to avoid reloading.
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
        self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
        self.vae = WanVAEWrapper() if vae is None else vae

        # Step 2: Initialize scheduler configuration (50 UniPC steps, flow shift 8.0).
        self.num_train_timesteps = args.num_train_timestep
        self.sampling_steps = 50
        self.sample_solver = 'unipc'
        self.shift = 8.0

        self.args = args
    def inference(
        self,
        noise: torch.Tensor,
        text_prompts: List[str],
        return_latents=False
    ) -> torch.Tensor:
        """
        Run bidirectional diffusion sampling on the given noise and text prompts.

        Inputs:
            noise (torch.Tensor): The input noise tensor of shape
                (batch_size, num_frames, num_channels, height, width).
            text_prompts (List[str]): The list of text prompts, one per batch element.
            return_latents (bool): If True, also return the final denoised latents.

        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_frames, num_channels, height, width),
                normalized to the range [0, 1]. When return_latents is True,
                a (video, latents) tuple is returned instead.
        """
        # Encode the positive prompts and a batch-matched negative prompt
        # for classifier-free guidance.
        conditional_dict = self.text_encoder(
            text_prompts=text_prompts
        )
        unconditional_dict = self.text_encoder(
            text_prompts=[self.args.negative_prompt] * len(text_prompts)
        )

        latents = noise
        sample_scheduler = self._initialize_sample_scheduler(noise)

        for t in tqdm(sample_scheduler.timesteps):
            latent_model_input = latents
            # Broadcast the scalar timestep across the batch and the 21 latent frames.
            timestep = t * torch.ones(
                [latents.shape[0], 21], device=noise.device, dtype=torch.float32)

            # Classifier-free guidance: blend the conditional and unconditional
            # flow predictions with the configured guidance scale.
            flow_pred_cond, _ = self.generator(latent_model_input, conditional_dict, timestep)
            flow_pred_uncond, _ = self.generator(latent_model_input, unconditional_dict, timestep)
            flow_pred = flow_pred_uncond + self.args.guidance_scale * (
                flow_pred_cond - flow_pred_uncond)

            temp_x0 = sample_scheduler.step(
                flow_pred.unsqueeze(0),
                t,
                latents.unsqueeze(0),
                return_dict=False)[0]
            latents = temp_x0.squeeze(0)

        # Decode the final latents to pixels and map from [-1, 1] to [0, 1].
        x0 = latents
        video = self.vae.decode_to_pixel(x0)
        video = (video * 0.5 + 0.5).clamp(0, 1)

        del sample_scheduler

        if return_latents:
            return video, latents
        else:
            return video
    def _initialize_sample_scheduler(self, noise):
        # Build the flow-matching scheduler for the configured solver.
        if self.sample_solver == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sample_scheduler.set_timesteps(
                self.sampling_steps, device=noise.device, shift=self.shift)
            self.timesteps = sample_scheduler.timesteps
        elif self.sample_solver == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
            self.timesteps, _ = retrieve_timesteps(
                sample_scheduler,
                device=noise.device,
                sigmas=sampling_sigmas)
        else:
            raise NotImplementedError(
                f"Unsupported solver: {self.sample_solver}")
        return sample_scheduler
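

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline above. The `args` values
    # and the noise shape are assumptions for illustration: the pipeline only
    # reads `num_train_timestep`, `negative_prompt`, `guidance_scale`, and an
    # optional `model_kwargs` dict from `args`, and the sampling loop assumes
    # 21 latent frames (the hardcoded frame count in `inference`).
    from types import SimpleNamespace

    args = SimpleNamespace(
        num_train_timestep=1000,  # assumed flow-matching training horizon
        negative_prompt="",       # assumed empty negative prompt
        guidance_scale=6.0,       # assumed CFG scale
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipeline = BidirectionalDiffusionInferencePipeline(args, device).to(device)

    # Assumed latent layout: 21 frames, 16 channels, 60x104 spatial grid;
    # the real shape is whatever WanDiffusionWrapper / WanVAEWrapper expect.
    noise = torch.randn(1, 21, 16, 60, 104, device=device)
    video = pipeline.inference(noise, ["a corgi running on the beach"])
    print(video.shape)  # (batch_size, num_frames, num_channels, height, width)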