LPX55 committed
Commit a8d7f39 · verified · 1 Parent(s): 10f65ab

Create app.py

Files changed (1)
  1. app.py +330 -0
app.py ADDED
@@ -0,0 +1,330 @@
+ import gradio as gr
+ import safetensors.torch
+ import torchvision.transforms.v2 as transforms
+ import cv2
+ import torch
+ import numpy as np
+ from typing import List, Optional, Tuple, Union
+ from PIL import Image
+ from diffusers import HunyuanVideoPipeline, FlowMatchEulerDiscreteScheduler
+ from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed, HunyuanVideoTransformer3DModel
+ from diffusers.utils import export_to_video
+ from diffusers.models.attention import Attention
+ from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+ from diffusers.models.embeddings import apply_rotary_emb
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.loaders import HunyuanVideoLoraLoaderMixin
+ from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.video_processor import VideoProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps, DEFAULT_PROMPT_TEMPLATE
+ from diffusers.utils import load_image
+ from huggingface_hub import hf_hub_download
+ import requests
+
+ # Define video transformations
+ video_transforms = transforms.Compose(
+     [
+         transforms.Lambda(lambda x: x / 255.0),
+         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+     ]
+ )
+
+ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
+     """
+     Resize the image to the bucket resolution.
+     """
+     is_pil_image = isinstance(image, Image.Image)
+     if is_pil_image:
+         image_width, image_height = image.size
+     else:
+         image_height, image_width = image.shape[:2]
+     if bucket_reso == (image_width, image_height):
+         return np.array(image) if is_pil_image else image
+     bucket_width, bucket_height = bucket_reso
+     scale_width = bucket_width / image_width
+     scale_height = bucket_height / image_height
+     scale = max(scale_width, scale_height)
+     image_width = int(image_width * scale + 0.5)
+     image_height = int(image_height * scale + 0.5)
+     if scale > 1:
+         image = Image.fromarray(image) if not is_pil_image else image
+         image = image.resize((image_width, image_height), Image.LANCZOS)
+         image = np.array(image)
+     else:
+         image = np.array(image) if is_pil_image else image
+         image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
+     # crop the image to the bucket resolution
+     crop_left = (image_width - bucket_width) // 2
+     crop_top = (image_height - bucket_height) // 2
+     image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
+     return image
+
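+ # Build the HunyuanVideo pipeline and adapt it for keyframe conditioning with the i2v LoRA.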
+ def construct_video_pipeline(model_id: str, lora_path: str):
+     # Load model and LORA
+     transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+     pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+
+     # Enable memory savings
+     pipe.vae.enable_tiling()
+     pipe.enable_model_cpu_offload()
+
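+     # Widen the patch-embedding layer to twice the original input channels so that conditioning
+     # image latents can be concatenated with the noise latents; the new projection is
+     # zero-initialized and the pretrained weights are copied into the first half.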
+     with torch.no_grad():  # enable image inputs
+         initial_input_channels = pipe.transformer.config.in_channels
+         new_img_in = HunyuanVideoPatchEmbed(
+             patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
+             in_chans=pipe.transformer.config.in_channels * 2,
+             embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
+         )
+         new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
+         new_img_in.proj.weight.zero_()
+         new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
+         if pipe.transformer.x_embedder.proj.bias is not None:
+             new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
+         pipe.transformer.x_embedder = new_img_in
+
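+     # Load the keyframe-control LoRA weights, fuse them into the transformer, then drop the
+     # adapter bookkeeping so inference runs on the fused weights.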
+     lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
+     transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
+     pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
+     pipe.set_adapters(["i2v"], adapter_weights=[1.0])
+     pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
+     pipe.unload_lora_weights()
+
+     return pipe
+
+ def generate_video(prompt: str, frame1_url: str, frame2_url: str, guidance_scale: float, num_frames: int, num_inference_steps: int) -> str:
+     # Load and preprocess frames
+     cond_frame1 = Image.open(requests.get(frame1_url, stream=True).raw)
+     cond_frame2 = Image.open(requests.get(frame2_url, stream=True).raw)
+
+     height, width = 720, 1280
+     cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(width, height))
+     cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(width, height))
+
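+     # Build the conditioning clip: the two keyframes occupy the first and last frames, and every
+     # frame in between is left as zeros for the model to fill in.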
+     cond_video = np.zeros(shape=(num_frames, height, width, 3))
+     cond_video[0], cond_video[-1] = np.array(cond_frame1), np.array(cond_frame2)
+     cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2)
+     cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0)
+
+     # Initialize pipeline
+     model_id = "hunyuanvideo-community/HunyuanVideo"
+     lora_path = hf_hub_download("dashtoon/hunyuan-video-keyframe-control-lora", "i2v.sft")  # Replace with the actual LORA path
+     pipe = construct_video_pipeline(model_id, lora_path)
+
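+     # Encode the conditioning frames into VAE latent space; these latents are passed to
+     # call_pipe and reused at every denoising step.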
+     with torch.no_grad():
+         image_or_video = cond_video.to(device="cuda", dtype=pipe.dtype)
+         image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
+         cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample()
+         cond_latents = cond_latents * pipe.vae.config.scaling_factor
+         cond_latents = cond_latents.to(dtype=pipe.dtype)
+         assert not torch.any(torch.isnan(cond_latents))
+
+     # Generate video
+     video = call_pipe(
+         pipe,
+         prompt=prompt,
+         num_frames=num_frames,
+         num_inference_steps=num_inference_steps,
+         image_latents=cond_latents,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale,
+         generator=torch.Generator(device="cuda").manual_seed(0),
+     ).frames[0]
+
+     # Export to video and return the file path (gr.Video output expects a path, not raw bytes)
+     video_path = "output.mp4"
+     export_to_video(video, video_path, fps=24)
+
+     return video_path
+
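+ # call_pipe closely follows HunyuanVideoPipeline.__call__, with an extra image_latents argument
+ # carrying the precomputed keyframe conditioning latents.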
+ @torch.inference_mode()
+ def call_pipe(
+     pipe,
+     prompt: Union[str, List[str]] = None,
+     prompt_2: Union[str, List[str]] = None,
+     height: int = 720,
+     width: int = 1280,
+     num_frames: int = 129,
+     num_inference_steps: int = 50,
+     sigmas: Optional[List[float]] = None,
+     guidance_scale: float = 6.0,
+     num_videos_per_prompt: Optional[int] = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.Tensor] = None,
+     prompt_embeds: Optional[torch.Tensor] = None,
+     pooled_prompt_embeds: Optional[torch.Tensor] = None,
+     prompt_attention_mask: Optional[torch.Tensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     attention_kwargs: Optional[dict] = None,
+     callback_on_step_end: Optional[Union[callable, PipelineCallback, MultiPipelineCallbacks]] = None,
+     callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+     prompt_template: Optional[dict] = DEFAULT_PROMPT_TEMPLATE,
+     max_sequence_length: int = 256,
+     image_latents: Optional[torch.Tensor] = None,
+ ):
+     if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+         callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+     # 1. Check inputs. Raise error if not correct
+     pipe.check_inputs(
+         prompt,
+         prompt_2,
+         height,
+         width,
+         prompt_embeds,
+         callback_on_step_end_tensor_inputs,
+         prompt_template,
+     )
+
+     pipe._guidance_scale = guidance_scale
+     pipe._attention_kwargs = attention_kwargs
+     pipe._current_timestep = None
+     pipe._interrupt = False
+     device = pipe._execution_device
+
+     # 2. Define call parameters
+     if prompt is not None and isinstance(prompt, str):
+         batch_size = 1
+     elif prompt is not None and isinstance(prompt, list):
+         batch_size = len(prompt)
+     else:
+         batch_size = prompt_embeds.shape[0]
+
+     # 3. Encode input prompt
+     prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
+         prompt=prompt,
+         prompt_2=prompt_2,
+         prompt_template=prompt_template,
+         num_videos_per_prompt=num_videos_per_prompt,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         prompt_attention_mask=prompt_attention_mask,
+         device=device,
+         max_sequence_length=max_sequence_length,
+     )
+
+     transformer_dtype = pipe.transformer.dtype
+     prompt_embeds = prompt_embeds.to(transformer_dtype)
+     prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+     if pooled_prompt_embeds is not None:
+         pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+
+     # 4. Prepare timesteps
+     sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+     timesteps, num_inference_steps = retrieve_timesteps(
+         pipe.scheduler,
+         num_inference_steps,
+         device,
+         sigmas=sigmas,
+     )
+
+     # 5. Prepare latent variables
+     num_channels_latents = pipe.transformer.config.in_channels
+     num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1
+     latents = pipe.prepare_latents(
+         batch_size * num_videos_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         num_latent_frames,
+         torch.float32,
+         device,
+         generator,
+         latents,
+     )
+
+     # 6. Prepare guidance condition
+     guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+
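+     # At every step the conditioning image latents are concatenated with the noise latents along
+     # the channel dimension, matching the widened x_embedder set up in construct_video_pipeline.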
+     # 7. Denoising loop
+     num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
+     pipe._num_timesteps = len(timesteps)
+
+     with pipe.progress_bar(total=num_inference_steps) as progress_bar:
+         for i, t in enumerate(timesteps):
+             if pipe.interrupt:
+                 continue
+             pipe._current_timestep = t
+             latent_model_input = latents.to(transformer_dtype)
+             timestep = t.expand(latents.shape[0]).to(latents.dtype)
+             noise_pred = pipe.transformer(
+                 hidden_states=torch.cat([latent_model_input, image_latents], dim=1),
+                 timestep=timestep,
+                 encoder_hidden_states=prompt_embeds,
+                 encoder_attention_mask=prompt_attention_mask,
+                 pooled_projections=pooled_prompt_embeds,
+                 guidance=guidance,
+                 attention_kwargs=attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             if callback_on_step_end is not None:
+                 callback_kwargs = {}
+                 for k in callback_on_step_end_tensor_inputs:
+                     callback_kwargs[k] = locals()[k]
+                 callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)
+                 latents = callback_outputs.pop("latents", latents)
+                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+             # update the progress bar
+             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
+                 progress_bar.update()
+
+     pipe._current_timestep = None
+     if not output_type == "latent":
+         latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
+         video = pipe.vae.decode(latents, return_dict=False)[0]
+         video = pipe.video_processor.postprocess_video(video, output_type=output_type)
+     else:
+         video = latents
+
+     # Offload all models
+     pipe.maybe_free_model_hooks()
+
+     if not return_dict:
+         return (video,)
+     return HunyuanVideoPipelineOutput(frames=video)
+
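+ # Gradio UI: a prompt, two keyframe image URLs, and the main sampling controls.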
+ def main():
+     # Define the interface inputs
+     inputs = [
+         gr.Textbox(label="Prompt", value="a woman"),
+         gr.Textbox(label="Frame 1 URL", value="https://content.dashtoon.ai/stability-images/e524013d-55d4-483a-b80a-dfc51d639158.png"),
+         gr.Textbox(label="Frame 2 URL", value="https://content.dashtoon.ai/stability-images/0b29c296-0a90-4b92-96b9-1ed0ae21e480.png"),
+         gr.Slider(minimum=0.1, maximum=20, step=0.1, label="Guidance Scale", value=6.0),
+         gr.Slider(minimum=1, maximum=129, step=1, label="Number of Frames", value=77),
+         gr.Slider(minimum=1, maximum=100, step=1, label="Number of Inference Steps", value=50),
+     ]
+
+     # Define the interface outputs
+     outputs = [
+         gr.Video(label="Generated Video"),
+     ]
+
+     # Create the Gradio interface
+     iface = gr.Interface(
+         fn=generate_video,
+         inputs=inputs,
+         outputs=outputs,
+         title="Hunyuan Video Generator",
+         description="Generate videos using the HunyuanVideo model with a prompt and two frames as conditions.",
+     )
+
+     # Launch the Gradio app
+     iface.launch()
+
+ if __name__ == "__main__":
+     main()