LPX55 committed on
Commit 5a8cf56 · verified · 1 Parent(s): 6e842f9

Update app.py

Files changed (1)
  1. app.py +31 -35
app.py CHANGED
@@ -34,6 +34,35 @@ video_transforms = transforms.Compose(
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
     ]
 )
+model_id = "hunyuanvideo-community/HunyuanVideo"
+lora_path = hf_hub_download("dashtoon/hunyuan-video-keyframe-control-lora", "i2v.sft")  # Replace with the actual LORA path
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+
+# Enable memory savings
+pipe.vae.enable_tiling()
+pipe.enable_model_cpu_offload()
+
+with torch.no_grad():  # enable image inputs
+    initial_input_channels = pipe.transformer.config.in_channels
+    new_img_in = HunyuanVideoPatchEmbed(
+        patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
+        in_chans=pipe.transformer.config.in_channels * 2,
+        embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
+    )
+    new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
+    new_img_in.proj.weight.zero_()
+    new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
+    if pipe.transformer.x_embedder.proj.bias is not None:
+        new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
+    pipe.transformer.x_embedder = new_img_in
+
+lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
+transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
+pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
+pipe.set_adapters(["i2v"], adapter_weights=[1.0])
+pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
+pipe.unload_lora_weights()
 
 def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
     """
@@ -65,37 +94,6 @@ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
     image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
     return image
 
-def construct_video_pipeline(model_id: str, lora_path: str):
-    # Load model and LORA
-    transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
-    pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
-
-    # Enable memory savings
-    pipe.vae.enable_tiling()
-    pipe.enable_model_cpu_offload()
-
-    with torch.no_grad():  # enable image inputs
-        initial_input_channels = pipe.transformer.config.in_channels
-        new_img_in = HunyuanVideoPatchEmbed(
-            patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
-            in_chans=pipe.transformer.config.in_channels * 2,
-            embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
-        )
-        new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
-        new_img_in.proj.weight.zero_()
-        new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
-        if pipe.transformer.x_embedder.proj.bias is not None:
-            new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
-        pipe.transformer.x_embedder = new_img_in
-
-    lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
-    transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
-    pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
-    pipe.set_adapters(["i2v"], adapter_weights=[1.0])
-    pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
-    pipe.unload_lora_weights()
-
-    return pipe
 
 def generate_video(pipe, prompt: str, frame1_path: str, frame2_path: str, guidance_scale: float, num_frames: int, num_inference_steps: int) -> bytes:
     # Load and preprocess frames
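
The LoRA handling removed here (and re-added at module level in the first hunk) keeps only keys that start with "transformer." and contain "lora", and strips the "transformer." prefix before calling load_lora_into_transformer. A standalone sketch of that filtering, using made-up key names:

import torch

# Hypothetical checkpoint keys, for illustration only.
lora_state_dict = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 4),
    "transformer.transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(4, 4),
    "text_encoder.embed.weight": torch.zeros(4),  # dropped: not a transformer LoRA key
}

transformer_lora_state_dict = {
    k.replace("transformer.", ""): v
    for k, v in lora_state_dict.items()
    if k.startswith("transformer.") and "lora" in k
}
print(sorted(transformer_lora_state_dict))
# ['transformer_blocks.0.attn.to_q.lora_A.weight', 'transformer_blocks.0.attn.to_q.lora_B.weight']
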
@@ -317,13 +315,11 @@ def main():
     outputs = [
         gr.Video(label="Generated Video"),
     ]
-    # Create a wrapper function to pass the pre-initialized pipeline
-    def generate_video_wrapper(*args):
-        return generate_video(pipe, *args)
+
 
     # Create the Gradio interface
     iface = gr.Interface(
-        fn=generate_video_wrapper,
+        fn=generate_video,
         inputs=inputs,
         outputs=outputs,
         title="Hunyuan Video Generator",
 
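With the pipeline now constructed at module scope, the generate_video_wrapper indirection is dropped and gr.Interface receives generate_video directly; Gradio passes the values of the inputs components to fn positionally. If generate_video keeps the pipeline as its first parameter, one conventional way to pre-bind the module-level pipe is functools.partial (a hedged sketch, not code from this commit):

from functools import partial

# Hypothetical adapter: pre-binds `pipe` so Gradio only supplies the
# user-facing arguments (prompt, frame paths, guidance scale, ...).
iface = gr.Interface(
    fn=partial(generate_video, pipe),
    inputs=inputs,
    outputs=outputs,
    title="Hunyuan Video Generator",
)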