# WIP: Optimizations are coming!
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import safetensors.torch
import torch
import torchvision.transforms.v2 as transforms
from diffusers import HunyuanVideoPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from diffusers.models.attention import Attention
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE, retrieve_timesteps
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video, is_torch_xla_available, load_image, logging, replace_example_docstring
from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
import argparse
import os
import time
import random
import sys
# 20250305 pftq: load settings for customization ####
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_id", type=str, default="hunyuanvideo-community/HunyuanVideo")
parser.add_argument("--transformer_model_id", type=str, default="hunyuanvideo-community/HunyuanVideo")
parser.add_argument("--lora_path", type=str, default="i2v.sft")
parser.add_argument("--use_sage", action="store_true")
parser.add_argument("--use_flash", action="store_true")
parser.add_argument("--cfg", type=float, default=6.0)
parser.add_argument("--num_frames", type=int, default=77)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--seed", type=int, default=-1)
parser.add_argument("--prompt", type=str, default="a woman")
parser.add_argument("--height", type=int, default=1280)
parser.add_argument("--width", type=int, default=720)
parser.add_argument("--video_num", type=int, default=1)
parser.add_argument("--image1", type=str, default="https://content.dashtoon.ai/stability-images/e524013d-55d4-483a-b80a-dfc51d639158.png")
parser.add_argument("--image2", type=str, default="https://content.dashtoon.ai/stability-images/0b29c296-0a90-4b92-96b9-1ed0ae21e480.png")
parser.add_argument("--image3", type=str, default="")
parser.add_argument("--image4", type=str, default="")
parser.add_argument("--image5", type=str, default="")
parser.add_argument("--fps", type=int, default=24)
parser.add_argument("--mbps", type=float, default=7)
parser.add_argument("--color_match", action="store_true")
args = parser.parse_args()
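
# Example invocation (script name and file paths illustrative), using the flags defined above:
#   python hv_inference.py --prompt "a woman walking on the beach" --num_frames 77 --steps 50 \
#       --image1 first_frame.png --image2 last_frame.png --use_sage --color_match --mbps 7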
# 20250305 pftq: from main repo at https://github.com/dashtoon/hunyuan-video-keyframe-control-lora/blob/main/hv_control_lora_inference.py
use_sage = False
use_flash = False
if args.use_sage:
    try:
        from sageattention import sageattn, sageattn_varlen
        use_sage = True
    except ImportError:
        sageattn, sageattn_varlen = None, None
elif args.use_flash:
    try:
        import flash_attn
        from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
        use_flash = True
    except ImportError:
        flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None
print("Using SageAttention: " + str(use_sage))
print("Using FlashAttention: " + str(use_flash))
video_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda x: x / 255.0),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
)
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution.
    """
    is_pil_image = isinstance(image, Image.Image)
    if is_pil_image:
        image_width, image_height = image.size
    else:
        image_height, image_width = image.shape[:2]
    if bucket_reso == (image_width, image_height):
        return np.array(image) if is_pil_image else image
    bucket_width, bucket_height = bucket_reso
    scale_width = bucket_width / image_width
    scale_height = bucket_height / image_height
    scale = max(scale_width, scale_height)
    image_width = int(image_width * scale + 0.5)
    image_height = int(image_height * scale + 0.5)
    if scale > 1:
        image = Image.fromarray(image) if not is_pil_image else image
        image = image.resize((image_width, image_height), Image.LANCZOS)
        image = np.array(image)
    else:
        image = np.array(image) if is_pil_image else image
        image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
    # crop the image to the bucket resolution
    crop_left = (image_width - bucket_width) // 2
    crop_top = (image_height - bucket_height) // 2
    image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
    return image
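
# Example: resizing a 1920x1080 frame to bucket (720, 1280) upscales it to about 2276x1280
# (scale = 1280/1080, the larger of the two ratios) and then center-crops to 720x1280.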
# 20250305 pftq: from main repo at https://github.com/dashtoon/hunyuan-video-keyframe-control-lora/blob/main/hv_control_lora_inference.py
def get_cu_seqlens(attention_mask):
    """Calculate cu_seqlens_q, cu_seqlens_kv using attention_mask"""
    batch_size = attention_mask.shape[0]
    text_len = attention_mask.sum(dim=-1, dtype=torch.int)
    max_len = attention_mask.shape[-1]
    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        s = text_len[i]
        s1 = i * max_len + s
        s2 = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2
    return cu_seqlens
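
# Example: an attention_mask of [[1, 1, 1, 0], [1, 1, 1, 1]] (batch=2, max_len=4) gives
# text_len = [3, 4] and cu_seqlens = [0, 3, 4, 8, 8] -- for each batch item, the end of its
# valid tokens followed by the end of its padded slot, as expected by the varlen attention kernels.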
class HunyuanVideoFlashAttnProcessor:
    def __init__(self, use_flash_attn=True, use_sageattn=False):
        self.use_flash_attn = use_flash_attn
        self.use_sageattn = use_sageattn
        if self.use_flash_attn:
            assert flash_attn is not None, "Flash attention not available"
        if self.use_sageattn:
            assert sageattn is not None, "Sage attention not available"

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        if image_rotary_emb is not None:
            if attn.add_q_proj is None and encoder_hidden_states is not None:
                query = torch.cat(
                    [
                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)
        batch_size = hidden_states.shape[0]
        img_seq_len = hidden_states.shape[1]
        txt_seq_len = 0
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)
            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)
            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)
            txt_seq_len = encoder_hidden_states.shape[1]
        max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len
        cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)
        query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3])
        key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3])
        value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3])
        if self.use_flash_attn:
            hidden_states = flash_attn_varlen_func(
                query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
            )
        elif self.use_sageattn:
            hidden_states = sageattn_varlen(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        else:
            raise NotImplementedError("Please set use_flash_attn=True or use_sageattn=True")
        hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1)
        hidden_states = hidden_states.to(query.dtype)
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : -encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1] :],
            )
        if getattr(attn, "to_out", None) is not None:
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
        if getattr(attn, "to_add_out", None) is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        return hidden_states, encoder_hidden_states
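
# Note: this processor mirrors the stock HunyuanVideo attention path but packs the image and text
# tokens into a single variable-length sequence per batch and routes it through the flash-attn /
# sage-attention varlen kernels; it is assigned to every transformer block further below when
# --use_flash or --use_sage is set.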
def call_pipe(
    pipe,
    prompt: Union[str, List[str]] = None,
    prompt_2: Union[str, List[str]] = None,
    height: int = 720,
    width: int = 1280,
    num_frames: int = 129,
    num_inference_steps: int = 50,
    sigmas: List[float] = None,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
    max_sequence_length: int = 256,
    image_latents: Optional[torch.Tensor] = None,
):
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    pipe.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds,
        callback_on_step_end_tensor_inputs,
        prompt_template,
    )
    pipe._guidance_scale = guidance_scale
    pipe._attention_kwargs = attention_kwargs
    pipe._current_timestep = None
    pipe._interrupt = False
    device = pipe._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_template=prompt_template,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        device=device,
        max_sequence_length=max_sequence_length,
    )
    transformer_dtype = pipe.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
    if pooled_prompt_embeds is not None:
        pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
    timesteps, num_inference_steps = retrieve_timesteps(
        pipe.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
    )

    # 5. Prepare latent variables
    num_channels_latents = pipe.transformer.config.in_channels
    num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1
    latents = pipe.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_latent_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    # 6. Prepare guidance condition
    guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
    pipe._num_timesteps = len(timesteps)

    # 20250305 pftq: added to properly offload to CPU, was out of memory otherwise
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")
    torch.cuda.empty_cache()

    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipe.interrupt:
                continue
            pipe._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
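            # Keyframe conditioning: the VAE-encoded keyframe latents (image_latents) are concatenated
            # with the noisy latents along the channel dim; the x_embedder rebuilt later in this script
            # accepts the doubled channel count.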
            noise_pred = pipe.transformer(
                hidden_states=torch.cat([latent_model_input, image_latents], dim=1),
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    pipe._current_timestep = None

    if output_type != "latent":
        latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
        video = pipe.vae.decode(latents, return_dict=False)[0]
        video = pipe.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    pipe.maybe_free_model_hooks()

    if not return_dict:
        return (video,)
    return HunyuanVideoPipelineOutput(frames=video)
# 20250305 pftq: customizable bitrate
# Function to check if FFmpeg is installed
import subprocess  # For FFmpeg functionality

def is_ffmpeg_installed():
    try:
        subprocess.run(["ffmpeg", "-version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False
# FFmpeg-based video saving with bitrate control
def save_video_with_ffmpeg(frames, output_path, fps, bitrate_mbps, metadata_comment=None):
    frames = [np.array(frame) for frame in frames]
    height, width, _ = frames[0].shape
    bitrate = f"{bitrate_mbps}M"
    cmd = [
        "ffmpeg",
        "-y",
        "-f", "rawvideo",
        "-vcodec", "rawvideo",
        "-s", f"{width}x{height}",
        "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",
        "-c:v", "libx264",
        "-b:v", bitrate,
        "-pix_fmt", "yuv420p",
        "-preset", "medium",
    ]
    # Add metadata comment if provided
    if metadata_comment:
        cmd.extend(["-metadata", f"comment={metadata_comment}"])
    cmd.append(output_path)
    process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
    for frame in frames:
        process.stdin.write(frame.tobytes())
    process.stdin.close()
    process.wait()
    stderr_output = process.stderr.read().decode()
    if process.returncode != 0:
        print(f"FFmpeg error: {stderr_output}")
    else:
        print(f"Video saved to {output_path} with FFmpeg")
# Fallback OpenCV-based video saving
def save_video_with_opencv(frames, output_path, fps, bitrate_mbps):
    frames = [np.array(frame) for frame in frames]
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    # Note: cv2.CAP_PROP_BITRATE is not supported, so bitrate_mbps is ignored
    for frame in frames:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # Convert RGB to BGR for OpenCV
        writer.write(frame)
    writer.release()
    print(f"Video saved to {output_path} with OpenCV (bitrate control unavailable)")
# Wrapper to choose between FFmpeg and OpenCV
def save_video_with_quality(frames, output_path, fps, bitrate_mbps, metadata_comment=None):
    if is_ffmpeg_installed():
        save_video_with_ffmpeg(frames, output_path, fps, bitrate_mbps, metadata_comment)
    else:
        print("FFmpeg not found. Falling back to OpenCV (bitrate not customizable).")
        save_video_with_opencv(frames, output_path, fps, bitrate_mbps)
# Reconstruct the command line with quotes and a backslash+linebreak after each argument-value pair
def reconstruct_command_line(args, argv):
    cmd_parts = [argv[0]]  # Start with script name
    args_dict = vars(args)  # Convert args to dict
    i = 1
    while i < len(argv):
        arg = argv[i]
        if arg.startswith("--"):
            key = arg[2:]
            if key in args_dict:
                value = args_dict[key]
                if isinstance(value, bool):
                    if value:
                        cmd_parts.append(arg)  # Boolean flag
                    i += 1
                else:
                    # Combine argument and value into one part
                    if i + 1 < len(argv) and not argv[i + 1].startswith("--"):
                        next_val = argv[i + 1]
                        if isinstance(value, str):
                            cmd_parts.append(f'{arg} "{value}"')  # Quote strings
                        else:
                            cmd_parts.append(f"{arg} {value}")  # No quotes for numbers
                        i += 2
                    else:
                        # Handle missing value in argv (use parsed args)
                        if isinstance(value, str):
                            cmd_parts.append(f'{arg} "{value}"')
                        else:
                            cmd_parts.append(f"{arg} {value}")
                        i += 1
            else:
                # Unknown flag (argparse would normally reject it); skip it to avoid an infinite loop
                i += 1
        else:
            i += 1
    # Build a multi-line string with a backslash and newline after every part except the last
    if len(cmd_parts) > 1:
        result = ""
        for j, part in enumerate(cmd_parts):
            if j < len(cmd_parts) - 1:
                result += part + " \\\n"
            else:
                result += part  # No trailing backslash on the last part
        return result
    return cmd_parts[0]  # Single-arg case
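
# Example: sys.argv of ['hv_inference.py', '--prompt', 'a woman', '--steps', '50'] is rebuilt as
#   hv_inference.py \
#   --prompt "a woman" \
#   --steps 50
# (script name illustrative); the result is embedded in the mp4 comment metadata when saving below.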
# start executing here ###################
print("Initializing model...")
transformer_subfolder = "transformer"
if args.transformer_model_id == "Skywork/SkyReels-V1-Hunyuan-I2V":
    transformer_subfolder = ""  # 20250305 pftq: Error otherwise - Skywork/SkyReels-V1-Hunyuan-I2V does not appear to have a file named config.json.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_model_id, subfolder=transformer_subfolder, torch_dtype=torch.bfloat16)
pipe = HunyuanVideoPipeline.from_pretrained(args.base_model_id, transformer=transformer, torch_dtype=torch.bfloat16)

# Enable memory savings
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()

# Apply flash/sage attention to all transformer blocks
if use_sage or use_flash:
    for block in pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks:
        block.attn.processor = HunyuanVideoFlashAttnProcessor(use_flash_attn=use_flash, use_sageattn=use_sage)

with torch.no_grad():  # enable image inputs
    initial_input_channels = pipe.transformer.config.in_channels
    new_img_in = HunyuanVideoPatchEmbed(
        patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
        in_chans=pipe.transformer.config.in_channels * 2,
        embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
    )
    new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
    new_img_in.proj.weight.zero_()
    new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
    if pipe.transformer.x_embedder.proj.bias is not None:
        new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
    pipe.transformer.x_embedder = new_img_in
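
# The transformer's patch-embedding layer now accepts 2x the original latent channels
# (noisy latents + keyframe latents); the pretrained weights are copied into the first half
# and the extra input channels are zero-initialized.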
print("Loading lora...") | |
lora_state_dict = pipe.lora_state_dict(args.lora_path) | |
transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k} | |
pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe) | |
pipe.set_adapters(["i2v"], adapter_weights=[1.0]) | |
pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"]) | |
pipe.unload_lora_weights() | |
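
# Fusing bakes the "i2v" LoRA weights into the transformer; unload_lora_weights() then removes the
# separate adapter modules, so the denoising loop runs on the plain fused weights.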
print("Loading images...") | |
cond_frame1 = load_image(args.image1) | |
cond_frame2 = load_image(args.image2) | |
cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(args.width, args.height)) | |
cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(args.width, args.height)) | |
cond_video = np.zeros(shape=(args.num_frames, args.height, args.width, 3)) | |
# 20250305 pftq: Optional 3rd-5th frame, sadly doesn't work so easily, needs more code | |
cond_frame3 = None | |
cond_frame4 = None | |
cond_frame5 = None | |
if args.image3 != "": | |
cond_frame3 = load_image(args.image3) | |
cond_frame3 = resize_image_to_bucket(cond_frame3, bucket_reso=(args.width, args.height)) | |
if args.image4 !="": | |
cond_frame4 = load_image(args.image4) | |
cond_frame4 = resize_image_to_bucket(cond_frame4, bucket_reso=(args.width, args.height)) | |
if args.image5 !="": | |
cond_frame5 = load_image(args.image5) | |
cond_frame5 = resize_image_to_bucket(cond_frame5, bucket_reso=(args.width, args.height)) | |
if args.image5 != "" and args.image4 != "" and args.image3 !="" and args.image2 !="": | |
cond_video[0] = np.array(cond_frame1) | |
cond_video[args.num_frames//4] = np.array(cond_frame2) | |
cond_video[(args.num_frames * 2 )//4] = np.array(cond_frame3) | |
cond_video[(args.num_frames * 3 )//4] = np.array(cond_frame4) | |
cond_video[args.num_frames -1] = np.array(cond_frame5) | |
elif args.image4 != "" and args.image3 !="" and args.image2 !="": | |
cond_video[0] = np.array(cond_frame1) | |
cond_video[args.num_frames//3] = np.array(cond_frame2) | |
cond_video[(args.num_frames * 2 )//3] = np.array(cond_frame3) | |
cond_video[args.num_frames -1] = np.array(cond_frame4) | |
elif args.image3 != "" and args.image2 !="": | |
cond_video[0] = np.array(cond_frame1) | |
cond_video[args.num_frames//2] = np.array(cond_frame2) | |
cond_video[args.num_frames -1] = np.array(cond_frame3) | |
else: | |
cond_video[0] = np.array(cond_frame1) | |
cond_video[args.num_frames -1] = np.array(cond_frame2) | |
cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2) | |
cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0) | |
with torch.no_grad(): | |
image_or_video = cond_video.to(device="cuda", dtype=pipe.dtype) | |
image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] | |
cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample() | |
cond_latents = cond_latents * pipe.vae.config.scaling_factor | |
cond_latents = cond_latents.to(dtype=pipe.dtype) | |
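
# cond_latents now holds the VAE-encoded keyframe video (the supplied keyframes in place, black
# frames elsewhere); call_pipe() receives it as image_latents for channel-wise conditioning.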
for idx in range(args.video_num):  # 20250305 pftq: for loop for multiple videos per batch with varying seeds
    if args.seed == -1 or idx > 0:  # 20250305 pftq: seed argument ignored if asking for more than one video
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))

    # 20250223 pftq: More useful filename and higher customizable bitrate
    from datetime import datetime
    now = datetime.now()
    formatted_time = now.strftime('%Y-%m-%d_%H-%M-%S')
    video_out_file = formatted_time + f"_hunyuankeyframe_{args.width}-{args.num_frames}f_cfg-{args.cfg}_steps-{args.steps}_seed-{args.seed}_{args.prompt[:40].replace('/','')}_{idx}"

    command_line = reconstruct_command_line(args, sys.argv)  # 20250307: Store the full command line used in the mp4 comment, with quotes
    #print(f"Command-line received:\n{command_line}")

    print("Starting video generation #" + str(idx) + " for " + video_out_file)
    video = call_pipe(
        pipe,
        prompt=args.prompt,
        num_frames=args.num_frames,
        num_inference_steps=args.steps,
        image_latents=cond_latents,
        width=args.width,
        height=args.height,
        guidance_scale=args.cfg,
        generator=torch.Generator(device="cuda").manual_seed(args.seed),
    ).frames[0]

    # 20250305 pftq: Color match with direct MKL and temporal smoothing
    if args.color_match:
        #save_video_with_quality(video, f"{video_out_file}_raw.mp4", args.fps, args.mbps)
        print("Applying color matching to video...")
        from color_matcher import ColorMatcher
        from color_matcher.io_handler import load_img_file
        from color_matcher.normalizer import Normalizer

        # Load the reference image (image1)
        ref_img = load_img_file(args.image1)  # Original load
        cm = ColorMatcher()
        matched_video = []
        for frame in video:
            frame_rgb = np.array(frame)  # Direct PIL to numpy
            matched_frame = cm.transfer(src=frame_rgb, ref=ref_img, method='mkl')
            matched_frame = Normalizer(matched_frame).uint8_norm()
            matched_video.append(matched_frame)
        video = matched_video
        # END OF COLOR MATCHING

    print("Saving " + video_out_file)
    #export_to_video(final_video, "output.mp4", fps=24)
    save_video_with_quality(video, f"{video_out_file}.mp4", args.fps, args.mbps, command_line)