LPX55 committed on
Commit 3f44783 · verified · 1 Parent(s): c747be5

Create alt_gen.py

Files changed (1):
  1. alt_gen.py  +668  -0

alt_gen.py ADDED
@@ -0,0 +1,668 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import argparse
import os
import random
import sys
import time

import cv2
import numpy as np
import safetensors.torch
import torch
import torchvision.transforms.v2 as transforms
from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from diffusers.models.attention import Attention
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE, retrieve_timesteps
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import export_to_video, is_torch_xla_available, load_image, logging, replace_example_docstring
from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image

# 20250305 pftq: load settings for customization ####
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_id", type=str, default="hunyuanvideo-community/HunyuanVideo")
parser.add_argument("--transformer_model_id", type=str, default="hunyuanvideo-community/HunyuanVideo")
parser.add_argument("--lora_path", type=str, default="i2v.sft")
parser.add_argument("--use_sage", action="store_true")
parser.add_argument("--use_flash", action="store_true")
parser.add_argument("--cfg", type=float, default=6.0)
parser.add_argument("--num_frames", type=int, default=77)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--seed", type=int, default=-1)
parser.add_argument("--prompt", type=str, default="a woman")
parser.add_argument("--height", type=int, default=1280)
parser.add_argument("--width", type=int, default=720)
parser.add_argument("--video_num", type=int, default=1)
parser.add_argument("--image1", type=str, default="https://content.dashtoon.ai/stability-images/e524013d-55d4-483a-b80a-dfc51d639158.png")
parser.add_argument("--image2", type=str, default="https://content.dashtoon.ai/stability-images/0b29c296-0a90-4b92-96b9-1ed0ae21e480.png")
parser.add_argument("--image3", type=str, default="")
parser.add_argument("--image4", type=str, default="")
parser.add_argument("--image5", type=str, default="")
parser.add_argument("--fps", type=int, default=24)
parser.add_argument("--mbps", type=float, default=7)
parser.add_argument("--color_match", action="store_true")

args = parser.parse_args()
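# The block above is just CLI plumbing. A hypothetical invocation (file names and
# image paths below are placeholders, not part of the original script) might look like:
#   python alt_gen.py --lora_path i2v.sft --prompt "a woman walking" \
#       --image1 first_frame.png --image2 last_frame.png \
#       --num_frames 77 --steps 50 --cfg 6.0 --use_sage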

# 20250305 pftq: from main repo at https://github.com/dashtoon/hunyuan-video-keyframe-control-lora/blob/main/hv_control_lora_inference.py
# Default the optional attention backends to None so the availability checks below
# never hit an undefined name.
sageattn, sageattn_varlen = None, None
flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None
use_sage = False
use_flash = False
if args.use_sage:
    try:
        from sageattention import sageattn, sageattn_varlen
        use_sage = True
    except ImportError:
        sageattn, sageattn_varlen = None, None
elif args.use_flash:
    try:
        import flash_attn
        from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
        use_flash = True
    except ImportError:
        flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None
print("Using SageAttn: " + str(use_sage))
print("Using FlashAttn: " + str(use_flash))


video_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda x: x / 255.0),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
)


def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution.
    """
    is_pil_image = isinstance(image, Image.Image)
    if is_pil_image:
        image_width, image_height = image.size
    else:
        image_height, image_width = image.shape[:2]

    if bucket_reso == (image_width, image_height):
        return np.array(image) if is_pil_image else image

    bucket_width, bucket_height = bucket_reso

    scale_width = bucket_width / image_width
    scale_height = bucket_height / image_height
    scale = max(scale_width, scale_height)
    image_width = int(image_width * scale + 0.5)
    image_height = int(image_height * scale + 0.5)

    if scale > 1:
        image = Image.fromarray(image) if not is_pil_image else image
        image = image.resize((image_width, image_height), Image.LANCZOS)
        image = np.array(image)
    else:
        image = np.array(image) if is_pil_image else image
        image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)

    # crop the image to the bucket resolution
    crop_left = (image_width - bucket_width) // 2
    crop_top = (image_height - bucket_height) // 2
    image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]

    return image
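# Illustrative check of the resize math above (not part of the original file):
# a 1080x1080 source resized to bucket_reso=(1280, 720) scales by
# max(1280/1080, 720/1080) ≈ 1.185 to 1280x1280, then the center crop keeps rows
# 280..999 for a 1280x720 result; a 1920x1080 source scales straight to 1280x720
# with no crop needed.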

# 20250305 pftq: from main repo at https://github.com/dashtoon/hunyuan-video-keyframe-control-lora/blob/main/hv_control_lora_inference.py
def get_cu_seqlens(attention_mask):
    """Calculate cu_seqlens_q, cu_seqlens_kv using attention_mask"""
    batch_size = attention_mask.shape[0]
    text_len = attention_mask.sum(dim=-1, dtype=torch.int)
    max_len = attention_mask.shape[-1]

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        s = text_len[i]
        s1 = i * max_len + s
        s2 = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens
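# Worked example for the helper above (illustrative only): with batch_size=2,
# max_len=4 and per-sample valid-token counts [3, 2], the loop produces
# cu_seqlens = [0, 3, 4, 6, 8] -- alternating "end of valid tokens" and
# "end of padded slot" offsets, which is the packed layout the varlen attention
# kernels used below expect.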


class HunyuanVideoFlashAttnProcessor:
    def __init__(self, use_flash_attn=True, use_sageattn=False):
        self.use_flash_attn = use_flash_attn
        self.use_sageattn = use_sageattn
        if self.use_flash_attn:
            assert flash_attn is not None, "Flash attention not available"
        if self.use_sageattn:
            assert sageattn is not None, "Sage attention not available"

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            if attn.add_q_proj is None and encoder_hidden_states is not None:
                query = torch.cat(
                    [
                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)

        batch_size = hidden_states.shape[0]
        img_seq_len = hidden_states.shape[1]
        txt_seq_len = 0

        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)

            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)

            txt_seq_len = encoder_hidden_states.shape[1]

        max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len
        cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)

        # pack [B, H, S, D] -> [B*S, H, D] for the varlen attention kernels
        query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3])
        key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3])
        value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3])

        if self.use_flash_attn:
            hidden_states = flash_attn_varlen_func(
                query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
            )
        elif self.use_sageattn:
            hidden_states = sageattn_varlen(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        else:
            raise NotImplementedError("Please set use_flash_attn=True or use_sageattn=True")

        hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : -encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1] :],
            )

        if getattr(attn, "to_out", None) is not None:
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)

        if getattr(attn, "to_add_out", None) is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
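# Note (added for context): this processor mirrors the projection/rotary-embedding
# math of the stock diffusers HunyuanVideo attention processor, but routes the packed
# sequences through flash_attn_varlen_func / sageattn_varlen using the cu_seqlens
# layout above. It is attached to every transformer block further down, and only when
# --use_flash or --use_sage is passed.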

@torch.inference_mode()
def call_pipe(
    pipe,
    prompt: Union[str, List[str]] = None,
    prompt_2: Union[str, List[str]] = None,
    height: int = 720,
    width: int = 1280,
    num_frames: int = 129,
    num_inference_steps: int = 50,
    sigmas: List[float] = None,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
    max_sequence_length: int = 256,
    image_latents: Optional[torch.Tensor] = None,
):
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    pipe.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds,
        callback_on_step_end_tensor_inputs,
        prompt_template,
    )

    pipe._guidance_scale = guidance_scale
    pipe._attention_kwargs = attention_kwargs
    pipe._current_timestep = None
    pipe._interrupt = False

    device = pipe._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_template=prompt_template,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        device=device,
        max_sequence_length=max_sequence_length,
    )

    transformer_dtype = pipe.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
    if pooled_prompt_embeds is not None:
        pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
    timesteps, num_inference_steps = retrieve_timesteps(
        pipe.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
    )

    # 5. Prepare latent variables
    num_channels_latents = pipe.transformer.config.in_channels
    num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1
    latents = pipe.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_latent_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    # 6. Prepare guidance condition
    guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
    pipe._num_timesteps = len(timesteps)

    # 20250305 pftq: added to properly offload to CPU, was out of memory otherwise
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")
    torch.cuda.empty_cache()

    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipe.interrupt:
                continue

            pipe._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            noise_pred = pipe.transformer(
                hidden_states=torch.cat([latent_model_input, image_latents], dim=1),
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    pipe._current_timestep = None

    if not output_type == "latent":
        latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
        video = pipe.vae.decode(latents, return_dict=False)[0]
        video = pipe.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    pipe.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return HunyuanVideoPipelineOutput(frames=video)
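# Note (added for context): call_pipe follows the stock HunyuanVideoPipeline.__call__
# flow, with one key change -- at every denoising step the VAE-encoded keyframe latents
# (image_latents) are concatenated with the noise latents along the channel dimension
# before the transformer call, which is why the patch embedding is widened to
# 2x in_channels further below.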

# 20250305 pftq: customizable bitrate
import subprocess  # For FFmpeg functionality


# Function to check if FFmpeg is installed
def is_ffmpeg_installed():
    try:
        subprocess.run(["ffmpeg", "-version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


# FFmpeg-based video saving with bitrate control
def save_video_with_ffmpeg(frames, output_path, fps, bitrate_mbps, metadata_comment=None):
    frames = [np.array(frame) for frame in frames]
    height, width, _ = frames[0].shape
    bitrate = f"{bitrate_mbps}M"
    cmd = [
        "ffmpeg",
        "-y",
        "-f", "rawvideo",
        "-vcodec", "rawvideo",
        "-s", f"{width}x{height}",
        "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",
        "-c:v", "libx264",
        "-b:v", bitrate,
        "-pix_fmt", "yuv420p",
        "-preset", "medium",
    ]

    # Add metadata comment if provided
    if metadata_comment:
        cmd.extend(["-metadata", f"comment={metadata_comment}"])
    cmd.append(output_path)

    process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
    for frame in frames:
        process.stdin.write(frame.tobytes())
    process.stdin.close()
    process.wait()
    stderr_output = process.stderr.read().decode()
    if process.returncode != 0:
        print(f"FFmpeg error: {stderr_output}")
    else:
        print(f"Video saved to {output_path} with FFmpeg")


# Fallback OpenCV-based video saving
def save_video_with_opencv(frames, output_path, fps, bitrate_mbps):
    frames = [np.array(frame) for frame in frames]
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    # Note: cv2.CAP_PROP_BITRATE is not supported, so bitrate_mbps is ignored
    for frame in frames:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # Convert RGB to BGR for OpenCV
        writer.write(frame)
    writer.release()
    print(f"Video saved to {output_path} with OpenCV (bitrate control unavailable)")


# Wrapper to choose between FFmpeg and OpenCV
def save_video_with_quality(frames, output_path, fps, bitrate_mbps, metadata_comment=None):
    if is_ffmpeg_installed():
        save_video_with_ffmpeg(frames, output_path, fps, bitrate_mbps, metadata_comment)
    else:
        print("FFmpeg not found. Falling back to OpenCV (bitrate not customizable).")
        save_video_with_opencv(frames, output_path, fps, bitrate_mbps)
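# Hypothetical usage of the saving helpers above (the path and settings are
# placeholders, not from the original file):
#   save_video_with_quality(frames, "clip.mp4", fps=24, bitrate_mbps=7,
#                           metadata_comment="generated with alt_gen.py")
# With FFmpeg on PATH this pipes raw RGB frames into libx264 at the requested
# bitrate; otherwise it falls back to the OpenCV writer without bitrate control.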

# Reconstruct the command line with quotes and a backslash + line break after each argument/value pair
def reconstruct_command_line(args, argv):
    cmd_parts = [argv[0]]  # Start with the script name
    args_dict = vars(args)  # Convert args to a dict

    i = 1
    while i < len(argv):
        arg = argv[i]
        if arg.startswith("--"):
            key = arg[2:]
            if key in args_dict:
                value = args_dict[key]
                if isinstance(value, bool):
                    if value:
                        cmd_parts.append(arg)  # Boolean flag
                    i += 1
                else:
                    # Combine argument and value into one part
                    if i + 1 < len(argv) and not argv[i + 1].startswith("--"):
                        next_val = argv[i + 1]
                        if isinstance(value, str):
                            cmd_parts.append(f'{arg} "{value}"')  # Quote strings
                        else:
                            cmd_parts.append(f"{arg} {value}")  # No quotes for numbers
                        i += 2
                    else:
                        # Handle missing value in argv (use parsed args)
                        if isinstance(value, str):
                            cmd_parts.append(f'{arg} "{value}"')
                        else:
                            cmd_parts.append(f"{arg} {value}")
                        i += 1
            else:
                i += 1
        else:
            # Skip tokens that are not recognized options so the loop always advances
            i += 1

    # Build a multi-line string with backslash + newline after every part except the last
    if len(cmd_parts) > 1:
        result = ""
        for j, part in enumerate(cmd_parts):
            if j < len(cmd_parts) - 1:
                result += part + " \\\n"
            else:
                result += part  # No trailing backslash on the last part
        return result
    return cmd_parts[0]  # Single arg case


# start executing here ###################
print("Initializing model...")
transformer_subfolder = "transformer"
if args.transformer_model_id == "Skywork/SkyReels-V1-Hunyuan-I2V":
    # 20250305 pftq: error otherwise - Skywork/SkyReels-V1-Hunyuan-I2V does not appear to have a file named config.json under a transformer subfolder
    transformer_subfolder = ""
transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_model_id, subfolder=transformer_subfolder, torch_dtype=torch.bfloat16)
pipe = HunyuanVideoPipeline.from_pretrained(args.base_model_id, transformer=transformer, torch_dtype=torch.bfloat16)

# Enable memory savings
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()

# Apply flash/sage attention to all transformer blocks
if use_sage or use_flash:
    for block in pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks:
        block.attn.processor = HunyuanVideoFlashAttnProcessor(use_flash_attn=use_flash, use_sageattn=use_sage)

with torch.no_grad():  # enable image inputs
    initial_input_channels = pipe.transformer.config.in_channels
    new_img_in = HunyuanVideoPatchEmbed(
        patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
        in_chans=pipe.transformer.config.in_channels * 2,
        embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
    )
    new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
    new_img_in.proj.weight.zero_()
    new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)

    if pipe.transformer.x_embedder.proj.bias is not None:
        new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)

    pipe.transformer.x_embedder = new_img_in
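# Note (added for context): the replacement patch embedding accepts twice the original
# latent channels so noise latents and keyframe latents can be stacked channel-wise.
# Copying the pretrained weights into the first half and zeroing the second half means
# the widened embedder initially behaves like the base transformer; presumably the
# keyframe-control LoRA loaded next provides the weights that react to the extra channels.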

print("Loading lora...")
lora_state_dict = pipe.lora_state_dict(args.lora_path)
transformer_lora_state_dict = {k.replace("transformer.", ""): v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
pipe.set_adapters(["i2v"], adapter_weights=[1.0])
pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
pipe.unload_lora_weights()

print("Loading images...")
cond_frame1 = load_image(args.image1)
cond_frame2 = load_image(args.image2)

cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(args.width, args.height))
cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(args.width, args.height))

cond_video = np.zeros(shape=(args.num_frames, args.height, args.width, 3))

# 20250305 pftq: optional 3rd-5th frame, sadly doesn't work so easily, needs more code
cond_frame3 = None
cond_frame4 = None
cond_frame5 = None

if args.image3 != "":
    cond_frame3 = load_image(args.image3)
    cond_frame3 = resize_image_to_bucket(cond_frame3, bucket_reso=(args.width, args.height))
if args.image4 != "":
    cond_frame4 = load_image(args.image4)
    cond_frame4 = resize_image_to_bucket(cond_frame4, bucket_reso=(args.width, args.height))
if args.image5 != "":
    cond_frame5 = load_image(args.image5)
    cond_frame5 = resize_image_to_bucket(cond_frame5, bucket_reso=(args.width, args.height))

if args.image5 != "" and args.image4 != "" and args.image3 != "" and args.image2 != "":
    cond_video[0] = np.array(cond_frame1)
    cond_video[args.num_frames // 4] = np.array(cond_frame2)
    cond_video[(args.num_frames * 2) // 4] = np.array(cond_frame3)
    cond_video[(args.num_frames * 3) // 4] = np.array(cond_frame4)
    cond_video[args.num_frames - 1] = np.array(cond_frame5)
elif args.image4 != "" and args.image3 != "" and args.image2 != "":
    cond_video[0] = np.array(cond_frame1)
    cond_video[args.num_frames // 3] = np.array(cond_frame2)
    cond_video[(args.num_frames * 2) // 3] = np.array(cond_frame3)
    cond_video[args.num_frames - 1] = np.array(cond_frame4)
elif args.image3 != "" and args.image2 != "":
    cond_video[0] = np.array(cond_frame1)
    cond_video[args.num_frames // 2] = np.array(cond_frame2)
    cond_video[args.num_frames - 1] = np.array(cond_frame3)
else:
    cond_video[0] = np.array(cond_frame1)
    cond_video[args.num_frames - 1] = np.array(cond_frame2)

cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2)
cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0)

with torch.no_grad():
    image_or_video = cond_video.to(device="cuda", dtype=pipe.dtype)
    image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
    cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample()
    cond_latents = cond_latents * pipe.vae.config.scaling_factor
    cond_latents = cond_latents.to(dtype=pipe.dtype)
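# Note (added for context): the VAE compresses time by vae_scale_factor_temporal
# (4 for HunyuanVideo), so e.g. the default 77 input frames become
# (77 - 1) // 4 + 1 = 20 latent frames -- the same count call_pipe uses for the noise
# latents, which is what allows the channel-wise concatenation at each step.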

for idx in range(args.video_num):  # 20250305 pftq: loop for multiple videos per batch with varying seeds

    if args.seed == -1 or idx > 0:  # 20250305 pftq: seed argument ignored if asking for more than one video
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))

    # 20250223 pftq: more useful filename and higher customizable bitrate
    from datetime import datetime
    now = datetime.now()
    formatted_time = now.strftime('%Y-%m-%d_%H-%M-%S')
    video_out_file = formatted_time + f"_hunyuankeyframe_{args.width}-{args.num_frames}f_cfg-{args.cfg}_steps-{args.steps}_seed-{args.seed}_{args.prompt[:40].replace('/','')}_{idx}"
    command_line = reconstruct_command_line(args, sys.argv)  # 20250307: store the full command line used in the mp4 comment, with quotes
    #print(f"Command-line received:\n{command_line}")

    print("Starting video generation #" + str(idx) + " for " + video_out_file)
    video = call_pipe(
        pipe,
        prompt=args.prompt,
        num_frames=args.num_frames,
        num_inference_steps=args.steps,
        image_latents=cond_latents,
        width=args.width,
        height=args.height,
        guidance_scale=args.cfg,
        generator=torch.Generator(device="cuda").manual_seed(args.seed),
    ).frames[0]

    # 20250305 pftq: color match with direct MKL and temporal smoothing
    if args.color_match:
        #save_video_with_quality(video, f"{video_out_file}_raw.mp4", args.fps, args.mbps)
        print("Applying color matching to video...")
        from color_matcher import ColorMatcher
        from color_matcher.io_handler import load_img_file
        from color_matcher.normalizer import Normalizer

        # Load the reference image (image1)
        ref_img = load_img_file(args.image1)  # Original load
        cm = ColorMatcher()
        matched_video = []

        for frame in video:
            frame_rgb = np.array(frame)  # Direct PIL to numpy
            matched_frame = cm.transfer(src=frame_rgb, ref=ref_img, method='mkl')
            matched_frame = Normalizer(matched_frame).uint8_norm()
            matched_video.append(matched_frame)

        video = matched_video
        # END OF COLOR MATCHING

    print("Saving " + video_out_file)
    #export_to_video(final_video, "output.mp4", fps=24)
    save_video_with_quality(video, f"{video_out_file}.mp4", args.fps, args.mbps, command_line)