Update app.py
app.py
CHANGED
@@ -34,6 +34,35 @@ video_transforms = transforms.Compose(
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
     ]
 )
+model_id = "hunyuanvideo-community/HunyuanVideo"
+lora_path = hf_hub_download("dashtoon/hunyuan-video-keyframe-control-lora", "i2v.sft")  # Replace with the actual LORA path
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+
+# Enable memory savings
+pipe.vae.enable_tiling()
+pipe.enable_model_cpu_offload()
+
+with torch.no_grad():  # enable image inputs
+    initial_input_channels = pipe.transformer.config.in_channels
+    new_img_in = HunyuanVideoPatchEmbed(
+        patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
+        in_chans=pipe.transformer.config.in_channels * 2,
+        embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
+    )
+    new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
+    new_img_in.proj.weight.zero_()
+    new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
+    if pipe.transformer.x_embedder.proj.bias is not None:
+        new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
+    pipe.transformer.x_embedder = new_img_in
+
+lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
+transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
+pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
+pipe.set_adapters(["i2v"], adapter_weights=[1.0])
+pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
+pipe.unload_lora_weights()
 
 def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
     """
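The new module-level block doubles the patch embedder's in_chans so the transformer can accept image-conditioning latents concatenated channel-wise with the noise latents ("enable image inputs"), and it zero-initializes the widened projection before copying in the original x_embedder weights, so the extra channels contribute nothing until the LoRA weights are loaded. A minimal sketch of that weight-surgery pattern, using plain nn.Conv3d stand-ins with assumed channel and patch sizes rather than the real HunyuanVideo classes:

import torch
import torch.nn as nn

# Stand-ins for the patch embedder's projection; channel/embed/patch sizes are assumed for illustration.
old_proj = nn.Conv3d(16, 3072, kernel_size=(1, 2, 2), stride=(1, 2, 2))   # original in_channels
new_proj = nn.Conv3d(32, 3072, kernel_size=(1, 2, 2), stride=(1, 2, 2))   # in_channels * 2

with torch.no_grad():
    new_proj.weight.zero_()                          # widened input columns start at zero
    new_proj.weight[:, :16].copy_(old_proj.weight)   # original columns copied over unchanged
    new_proj.bias.copy_(old_proj.bias)

# With the conditioning half set to zeros, the widened projection reproduces the original output.
latents = torch.randn(1, 16, 5, 64, 64)
cond = torch.zeros_like(latents)
assert torch.allclose(old_proj(latents), new_proj(torch.cat([latents, cond], dim=1)), atol=1e-5)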
@@ -65,37 +94,6 @@ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
     image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
     return image
 
-def construct_video_pipeline(model_id: str, lora_path: str):
-    # Load model and LORA
-    transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
-    pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
-
-    # Enable memory savings
-    pipe.vae.enable_tiling()
-    pipe.enable_model_cpu_offload()
-
-    with torch.no_grad():  # enable image inputs
-        initial_input_channels = pipe.transformer.config.in_channels
-        new_img_in = HunyuanVideoPatchEmbed(
-            patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
-            in_chans=pipe.transformer.config.in_channels * 2,
-            embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
-        )
-        new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
-        new_img_in.proj.weight.zero_()
-        new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
-        if pipe.transformer.x_embedder.proj.bias is not None:
-            new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
-        pipe.transformer.x_embedder = new_img_in
-
-    lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
-    transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
-    pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
-    pipe.set_adapters(["i2v"], adapter_weights=[1.0])
-    pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
-    pipe.unload_lora_weights()
-
-    return pipe
 
 def generate_video(pipe, prompt: str, frame1_path: str, frame2_path: str, guidance_scale: float, num_frames: int, num_inference_steps: int) -> bytes:
     # Load and preprocess frames
@@ -317,13 +315,11 @@ def main():
     outputs = [
         gr.Video(label="Generated Video"),
     ]
-
-    def generate_video_wrapper(*args):
-        return generate_video(pipe, *args)
+
 
     # Create the Gradio interface
     iface = gr.Interface(
-        fn=
+        fn=generate_video,
         inputs=inputs,
         outputs=outputs,
         title="Hunyuan Video Generator",
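With the pipeline now constructed at import time, the interface passes generate_video to Gradio directly. The context shown earlier in the diff still has generate_video taking pipe as its first parameter, while gr.Interface supplies only the declared inputs; if that signature is kept, one way to bind the module-level pipe is a partial application (a sketch under that assumption, not the Space's actual code):

import functools
import gradio as gr

# Sketch: pre-bind the module-level pipeline so Gradio only has to supply the UI inputs.
iface = gr.Interface(
    fn=functools.partial(generate_video, pipe),
    inputs=inputs,
    outputs=outputs,
    title="Hunyuan Video Generator",
)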