Spaces:
Running
Running
File size: 13,799 Bytes
80ebcb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
from accelerate import init_empty_weights
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from PIL.Image import Image
from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel
from ... import data
from ... import functional as FF
from ...logging import get_logger
from ...processors import ProcessorMixin, T5Processor
from ...typing import ArtifactType, SchedulerType
from ...utils import get_non_null_items
from ..modeling_utils import ModelSpecification
# Module-level logger obtained from the project's logging helper (`...logging.get_logger`).
logger = get_logger()
class WanLatentEncodeProcessor(ProcessorMixin):
    r"""
    Processor to encode image/video into latents using the Wan VAE.

    When the VAE config provides per-channel ``latents_mean`` / ``latents_std``, the latents are normalized as
    ``(z - mean) / std``. This is the inverse of the ``z * std + mean`` de-normalization that diffusers' Wan
    pipelines apply before ``vae.decode``, so training happens in the same latent space used at inference.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor returns. Exactly one name must be given:
                - latents: With ``compute_posterior=True``, a (normalized) sample from the VAE posterior.
                  With ``compute_posterior=False``, the (normalized) posterior moments, i.e. mean and
                  log-variance concatenated along the channel dimension, suitable for
                  ``DiagonalGaussianDistribution``.
    """

    def __init__(self, output_names: List[str]):
        super().__init__()
        self.output_names = output_names
        # Raise instead of assert so the check survives `python -O`.
        if len(self.output_names) != 1:
            raise ValueError(
                f"WanLatentEncodeProcessor expects exactly 1 output name, got {len(self.output_names)}: "
                f"{self.output_names}"
            )

    @staticmethod
    def _latent_stats(vae: AutoencoderKLWan, device: torch.device) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Return the VAE config's per-channel latent (mean, std) as float32 [1, C, 1, 1, 1] tensors, or None."""
        latents_mean = vae.config.get("latents_mean", None)
        latents_std = vae.config.get("latents_std", None)
        if latents_mean is None or latents_std is None:
            return None
        mean = torch.tensor(latents_mean, device=device, dtype=torch.float32).view(1, -1, 1, 1, 1)
        std = torch.tensor(latents_std, device=device, dtype=torch.float32).view(1, -1, 1, 1, 1)
        return mean, std

    def forward(
        self,
        vae: AutoencoderKLWan,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Encode ``image`` or ``video`` with ``vae``.

        Args:
            vae: The Wan VAE used for encoding; its device and dtype are used for the input and output.
            image: Optional image batch; if given it takes precedence over ``video`` and is treated as a
                single-frame video via ``unsqueeze(1)``.
            video: Optional video batch laid out as [B, F, C, H, W].
            generator: RNG used when sampling from the posterior.
            compute_posterior: If True, return a posterior sample; otherwise return the raw posterior moments
                (sampling is then deferred to the caller).

        Returns:
            ``{output_names[0]: latents}`` cast to the VAE dtype.

        Raises:
            ValueError: if neither ``image`` nor ``video`` is given, or the input is not 5D.
        """
        device = vae.device
        dtype = vae.dtype

        if image is not None:
            video = image.unsqueeze(1)
        if video is None:
            raise ValueError("Either `image` or `video` must be provided.")
        if video.ndim != 5:
            raise ValueError(f"Expected 5D tensor, got {video.ndim}D tensor")

        video = video.to(device=device, dtype=dtype)
        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]

        stats = self._latent_stats(vae, device)

        if compute_posterior:
            latents = vae.encode(video).latent_dist.sample(generator=generator)
            if stats is not None:
                mean, std = stats
                # Normalize in float32 for precision, then cast back to the VAE dtype below.
                latents = (latents.float() - mean) / std
            latents = latents.to(dtype=dtype)
        else:
            # TODO(aryan): refactor in diffusers to have use_slicing attribute
            # if vae.use_slicing and video.shape[0] > 1:
            #     encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
            #     moments = torch.cat(encoded_slices)
            # else:
            #     moments = vae._encode(video)
            moments = vae._encode(video)
            if stats is not None:
                mean, std = stats
                mu, logvar = torch.chunk(moments.float(), 2, dim=1)
                mu = (mu - mean) / std
                # If z ~ N(mu, var), then (z - mean) / std has variance var / std**2,
                # i.e. the log-variance shifts by -2 * log(std).
                logvar = logvar - 2.0 * torch.log(std)
                moments = torch.cat([mu, logvar], dim=1)
            latents = moments.to(dtype=dtype)

        return {self.output_names[0]: latents}
class WanModelSpecification(ModelSpecification):
    r"""
    Model specification for Wan models loaded through diffusers (default: ``Wan-AI/Wan2.1-T2V-1.3B-Diffusers``).

    Loads the tokenizer / text encoder, the ``AutoencoderKLWan`` VAE, the ``WanTransformer3DModel`` and a
    ``FlowMatchEulerDiscreteScheduler``. It also implements the flow-matching training step (``forward``),
    validation and checkpoint saving.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        tokenizer_id: Optional[str] = None,
        text_encoder_id: Optional[str] = None,
        transformer_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        text_encoder_dtype: torch.dtype = torch.bfloat16,
        transformer_dtype: torch.dtype = torch.bfloat16,
        vae_dtype: torch.dtype = torch.bfloat16,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        condition_model_processors: Optional[List[ProcessorMixin]] = None,
        latent_model_processors: Optional[List[ProcessorMixin]] = None,
        **kwargs,
    ) -> None:
        # Extra keyword arguments are accepted for interface compatibility but are not used here.
        super().__init__(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            tokenizer_id=tokenizer_id,
            text_encoder_id=text_encoder_id,
            transformer_id=transformer_id,
            vae_id=vae_id,
            text_encoder_dtype=text_encoder_dtype,
            transformer_dtype=transformer_dtype,
            vae_dtype=vae_dtype,
            revision=revision,
            cache_dir=cache_dir,
        )

        # Default processors: T5 text conditioning and Wan VAE latent encoding.
        if condition_model_processors is None:
            condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
        if latent_model_processors is None:
            latent_model_processors = [WanLatentEncodeProcessor(["latents"])]

        self.condition_model_processors = condition_model_processors
        self.latent_model_processors = latent_model_processors

    @property
    def _resolution_dim_keys(self):
        # TODO
        # Maps a latent key to the tensor dims holding its (frames, height, width); here dims 2, 3, 4 of
        # a [B, C, F, H, W] latent.
        return {
            "latents": (2, 3, 4),
        }

    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
        """Load the tokenizer and text encoder, preferring explicit ids over the base checkpoint subfolders."""
        if self.tokenizer_id is not None:
            tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="tokenizer",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        if self.text_encoder_id is not None:
            text_encoder = AutoModel.from_pretrained(
                self.text_encoder_id,
                torch_dtype=self.text_encoder_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            text_encoder = UMT5EncoderModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="text_encoder",
                torch_dtype=self.text_encoder_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        return {"tokenizer": tokenizer, "text_encoder": text_encoder}

    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
        """Load the Wan VAE from ``vae_id`` if set, otherwise from the base checkpoint's ``vae`` subfolder."""
        if self.vae_id is not None:
            vae = AutoencoderKLWan.from_pretrained(
                self.vae_id,
                torch_dtype=self.vae_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            vae = AutoencoderKLWan.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="vae",
                torch_dtype=self.vae_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        return {"vae": vae}

    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
        """Load the transformer and create a default-configured flow-matching Euler scheduler."""
        if self.transformer_id is not None:
            transformer = WanTransformer3DModel.from_pretrained(
                self.transformer_id,
                torch_dtype=self.transformer_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            transformer = WanTransformer3DModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="transformer",
                torch_dtype=self.transformer_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {"transformer": transformer, "scheduler": scheduler}

    def load_pipeline(
        self,
        tokenizer: Optional[AutoTokenizer] = None,
        text_encoder: Optional[UMT5EncoderModel] = None,
        transformer: Optional[WanTransformer3DModel] = None,
        vae: Optional[AutoencoderKLWan] = None,
        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
        enable_slicing: bool = False,
        enable_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        training: bool = False,
        **kwargs,
    ) -> WanPipeline:
        """
        Build a ``WanPipeline``; any component passed as None is loaded from the base checkpoint.

        The text encoder and VAE are cast to their configured dtypes. The transformer is cast only when
        ``training`` is False, which leaves a model that is being trained in whatever dtype it currently has.
        ``enable_slicing`` / ``enable_tiling`` are currently ignored (see TODO below).
        """
        components = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
        }
        components = get_non_null_items(components)

        pipe = WanPipeline.from_pretrained(
            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
        )
        pipe.text_encoder.to(self.text_encoder_dtype)
        pipe.vae.to(self.vae_dtype)

        if not training:
            pipe.transformer.to(self.transformer_dtype)

        # TODO(aryan): add support in diffusers
        # if enable_slicing:
        #     pipe.vae.enable_slicing()
        # if enable_tiling:
        #     pipe.vae.enable_tiling()
        if enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()

        return pipe

    @torch.no_grad()
    def prepare_conditions(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        caption: str,
        max_sequence_length: int = 512,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Run the condition processors on ``caption`` and return only the newly produced outputs.

        Input keys are filtered out of the result, and ``prompt_attention_mask`` is dropped because ``forward``
        only passes ``prompt_embeds`` (as ``encoder_hidden_states``) to the transformer.
        """
        conditions = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "caption": caption,
            "max_sequence_length": max_sequence_length,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_conditions(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        conditions.pop("prompt_attention_mask", None)
        return conditions

    @torch.no_grad()
    def prepare_latents(
        self,
        vae: AutoencoderKLWan,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Run the latent processors and return only the newly produced outputs (input keys filtered out)."""
        conditions = {
            "vae": vae,
            "image": image,
            "video": video,
            "generator": generator,
            "compute_posterior": compute_posterior,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_latents(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions

    def forward(
        self,
        transformer: WanTransformer3DModel,
        condition_model_conditions: Dict[str, torch.Tensor],
        latent_model_conditions: Dict[str, torch.Tensor],
        sigmas: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        r"""
        Perform one flow-matching training step.

        Args:
            transformer: The denoiser being trained.
            condition_model_conditions: Must contain ``prompt_embeds``; every other entry is forwarded to the
                transformer as a keyword argument.
            latent_model_conditions: Must contain ``latents``. With ``compute_posterior=True`` these are the
                latents to noise directly. Otherwise they are posterior moments, which are sampled here with
                ``generator``.
            sigmas: Noise levels; the transformer timestep is ``int(sigma * 1000)``.
            generator: RNG for the posterior sample (if any) and the Gaussian noise.

        Returns:
            ``(pred, target, sigmas)``, where ``target`` comes from ``FF.flow_match_target(noise, latents)``.

        The input dictionaries are not mutated.
        """
        # Shallow copies: the pops/assignments below must not alter the caller's dicts. Otherwise reusing
        # them (e.g. calling forward twice with the same conditions) fails with KeyError.
        condition_model_conditions = dict(condition_model_conditions)
        latent_model_conditions = dict(latent_model_conditions)

        latents = latent_model_conditions.pop("latents")
        if not compute_posterior:
            # The latent processor returned posterior moments; draw the sample here.
            posterior = DiagonalGaussianDistribution(latents)
            latents = posterior.sample(generator=generator)
            del posterior

        noise = torch.zeros_like(latents).normal_(generator=generator)
        noisy_latents = FF.flow_match_xt(latents, noise, sigmas)

        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
        condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")

        timesteps = (sigmas.flatten() * 1000.0).long()

        pred = transformer(
            **latent_model_conditions,
            **condition_model_conditions,
            timestep=timesteps,
            return_dict=False,
        )[0]
        target = FF.flow_match_target(noise, latents)

        return pred, target, sigmas

    def validation(
        self,
        pipeline: WanPipeline,
        prompt: str,
        image: Optional[Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 50,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> List[ArtifactType]:
        """
        Generate one validation video and wrap it in a ``VideoArtifact``.

        When ``image`` is given, the pipeline is converted with ``WanImageToVideoPipeline.from_pipe``.
        NOTE(review): a text-to-video pipeline may lack the components the image-to-video pipeline expects;
        confirm this path against the checkpoint in use. Arguments left as None fall back to pipeline defaults.
        """
        if image is not None:
            pipeline = WanImageToVideoPipeline.from_pipe(pipeline)

        generation_kwargs = {
            "prompt": prompt,
            "image": image,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "return_dict": True,
            "output_type": "pil",
        }
        generation_kwargs = get_non_null_items(generation_kwargs)
        video = pipeline(**generation_kwargs).frames[0]
        return [data.VideoArtifact(value=video)]

    def _save_lora_weights(
        self,
        directory: str,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
        *args,
        **kwargs,
    ) -> None:
        """Save transformer LoRA weights (safetensors) and, if given, the scheduler config under ``directory``."""
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            WanPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def _save_model(
        self,
        directory: str,
        transformer: WanTransformer3DModel,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        """
        Save a full transformer checkpoint and, if given, the scheduler under ``directory``.

        The transformer is rebuilt from ``transformer.config`` on empty (meta) weights. The provided state dict
        is then assigned into it (``assign=True``), which avoids allocating a second full copy of the parameters.
        """
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            with init_empty_weights():
                transformer_copy = WanTransformer3DModel.from_config(transformer.config)
            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
|