LPX55 committed on
Commit 24d345b · verified · 1 Parent(s): 216eb04

Create app_optimized.py

Files changed (1)
  1. app_optimized.py +344 -0
app_optimized.py ADDED
@@ -0,0 +1,344 @@
+ import gradio as gr
+ import safetensors.torch
+ import torchvision.transforms.v2 as transforms
+ import cv2
+ import torch
+ import numpy as np
+ from typing import List, Optional, Tuple, Union
+ from PIL import Image
+ import io
+ from io import BytesIO
+ from diffusers import HunyuanVideoPipeline, FlowMatchEulerDiscreteScheduler
+ from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed, HunyuanVideoTransformer3DModel
+ from diffusers.utils import export_to_video
+ from diffusers.models.attention import Attention
+ from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+ from diffusers.models.embeddings import apply_rotary_emb
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.loaders import HunyuanVideoLoraLoaderMixin
+ from diffusers.models import AutoencoderKLHunyuanVideo
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.video_processor import VideoProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import retrieve_timesteps, DEFAULT_PROMPT_TEMPLATE
+ from diffusers.utils import load_image
+ from huggingface_hub import hf_hub_download
+ import requests
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ # Define video transformations
+ video_transforms = transforms.Compose(
+     [
+         transforms.Lambda(lambda x: x / 255.0),
+         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+     ]
+ )
+ 
+ model_id = "hunyuanvideo-community/HunyuanVideo"
+ lora_path = hf_hub_download("dashtoon/hunyuan-video-keyframe-control-lora", "i2v.sft")  # Replace with the actual LoRA path
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
+ pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
+ 
+ # Enable memory savings
+ pipe.vae.enable_tiling()
+ pipe.enable_model_cpu_offload()
+ 
+ # Double the patch-embedding input channels so conditioning latents can be fed
+ # alongside the noise latents.
+ with torch.no_grad():  # enable image inputs
+     initial_input_channels = pipe.transformer.config.in_channels
+     new_img_in = HunyuanVideoPatchEmbed(
+         patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
+         in_chans=pipe.transformer.config.in_channels * 2,
+         embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
+     )
+     new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
+     new_img_in.proj.weight.zero_()
+     new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
+     if pipe.transformer.x_embedder.proj.bias is not None:
+         new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
+     pipe.transformer.x_embedder = new_img_in
+ 
+ lora_state_dict = safetensors.torch.load_file(lora_path, device="cpu")
+ transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
+ pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
+ pipe.set_adapters(["i2v"], adapter_weights=[1.0])
+ pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
+ pipe.unload_lora_weights()
+ 
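+ # Summary of the setup above: the x_embedder now takes twice the latent
+ # channels (pretrained weights in the first half, zeros in the second), and
+ # the LoRA state dict is remapped by stripping the "transformer." prefix
+ # (keeping only keys containing "lora") before being fused and unloaded.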
+ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: Tuple[int, int]) -> np.ndarray:
+     """
+     Resize the image to the bucket resolution.
+     """
+     if isinstance(image, Image.Image):
+         image = np.array(image)
+     elif not isinstance(image, np.ndarray):
+         raise ValueError("Image must be a PIL Image or NumPy array")
+ 
+     image_height, image_width = image.shape[:2]
+     if bucket_reso == (image_width, image_height):
+         return image
+     bucket_width, bucket_height = bucket_reso
+     scale_width = bucket_width / image_width
+     scale_height = bucket_height / image_height
+     scale = max(scale_width, scale_height)
+     image_width = int(image_width * scale + 0.5)
+     image_height = int(image_height * scale + 0.5)
+     if scale > 1:
+         image = Image.fromarray(image)
+         image = image.resize((image_width, image_height), Image.LANCZOS)
+         image = np.array(image)
+     else:
+         image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
+     # crop the image to the bucket resolution
+     crop_left = (image_width - bucket_width) // 2
+     crop_top = (image_height - bucket_height) // 2
+     image = image[crop_top:crop_top + bucket_height, crop_left:crop_left + bucket_width]
+     return image
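+ # Worked example (illustrative numbers): resizing a 1920x1080 frame into the
+ # 544x960 bucket uses scale = max(544/1920, 960/1080) ~= 0.889, giving a
+ # 1707x960 intermediate image that is then center-cropped to 544x960.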
+ 
+ 
+ def generate_video(prompt: str, frame1: Image.Image, frame2: Image.Image, resolution: str, guidance_scale: float, num_frames: int, num_inference_steps: int, fps: int) -> str:
+     # Debugging print statements
+     print(f"Frame 1 Type: {type(frame1)}")
+     print(f"Frame 2 Type: {type(frame2)}")
+     print(f"Resolution: {resolution}")
+ 
+     # Parse resolution
+     width, height = map(int, resolution.split('x'))
+ 
+     # Load and preprocess frames
+     cond_frame1 = np.array(frame1)
+     cond_frame2 = np.array(frame2)
+     cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(width, height))
+     cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(width, height))
+     cond_video = np.zeros(shape=(num_frames, height, width, 3))
+     cond_video[0], cond_video[-1] = cond_frame1, cond_frame2
+     cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2)
+     cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0)
+ 
+     with torch.no_grad():
+         image_or_video = cond_video.to(device="cuda", dtype=pipe.dtype)
+         image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
+         cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample()
+         cond_latents = cond_latents * pipe.vae.config.scaling_factor
+         cond_latents = cond_latents.to(device=device, dtype=pipe.dtype)
+         assert not torch.any(torch.isnan(cond_latents))
+ 
+     # Generate video
+     video = call_pipe(
+         pipe,
+         prompt=prompt,
+         num_frames=num_frames,
+         num_inference_steps=num_inference_steps,
+         image_latents=cond_latents,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale,
+         generator=torch.Generator(device="cuda").manual_seed(0),
+     ).frames[0]
+ 
+     # Export to video
+     video_path = "output.mp4"
+     # video_bytes = io.BytesIO()
+     export_to_video(video, video_path, fps=fps)
+     torch.cuda.empty_cache()
+     return video_path
+ 
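+ # Programmatic usage sketch (hypothetical file names; the Gradio UI below is
+ # the intended entry point):
+ #
+ #   first = Image.open("frame_first.png").convert("RGB")
+ #   last = Image.open("frame_last.png").convert("RGB")
+ #   path = generate_video("a woman", first, last, "544x960",
+ #                         guidance_scale=6.0, num_frames=49,
+ #                         num_inference_steps=30, fps=16)
+ #   print(path)  # -> "output.mp4"
+ 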
+ @torch.inference_mode()
+ def call_pipe(
+     pipe,
+     prompt: Union[str, List[str]] = None,
+     prompt_2: Union[str, List[str]] = None,
+     height: int = 720,
+     width: int = 1280,
+     num_frames: int = 129,
+     num_inference_steps: int = 50,
+     sigmas: Optional[List[float]] = None,
+     guidance_scale: float = 6.0,
+     num_videos_per_prompt: Optional[int] = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.Tensor] = None,
+     prompt_embeds: Optional[torch.Tensor] = None,
+     pooled_prompt_embeds: Optional[torch.Tensor] = None,
+     prompt_attention_mask: Optional[torch.Tensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     attention_kwargs: Optional[dict] = None,
+     callback_on_step_end: Optional[Union[callable, PipelineCallback, MultiPipelineCallbacks]] = None,
+     callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+     prompt_template: Optional[dict] = DEFAULT_PROMPT_TEMPLATE,
+     max_sequence_length: int = 256,
+     image_latents: Optional[torch.Tensor] = None,
+ ):
+     if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+         callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ 
+     # 1. Check inputs. Raise error if not correct
+     pipe.check_inputs(
+         prompt,
+         prompt_2,
+         height,
+         width,
+         prompt_embeds,
+         callback_on_step_end_tensor_inputs,
+         prompt_template,
+     )
+ 
+     pipe._guidance_scale = guidance_scale
+     pipe._attention_kwargs = attention_kwargs
+     pipe._current_timestep = None
+     pipe._interrupt = False
+     device = pipe._execution_device
+ 
+     # 2. Define call parameters
+     if prompt is not None and isinstance(prompt, str):
+         batch_size = 1
+     elif prompt is not None and isinstance(prompt, list):
+         batch_size = len(prompt)
+     else:
+         batch_size = prompt_embeds.shape[0]
+ 
+     # 3. Encode input prompt
+     prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
+         prompt=prompt,
+         prompt_2=prompt_2,
+         prompt_template=prompt_template,
+         num_videos_per_prompt=num_videos_per_prompt,
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         prompt_attention_mask=prompt_attention_mask,
+         device=device,
+         max_sequence_length=max_sequence_length,
+     )
+ 
+     transformer_dtype = pipe.transformer.dtype
+     prompt_embeds = prompt_embeds.to(transformer_dtype)
+     prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+     if pooled_prompt_embeds is not None:
+         pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+ 
+     # 4. Prepare timesteps
+     sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+     timesteps, num_inference_steps = retrieve_timesteps(
+         pipe.scheduler,
+         num_inference_steps,
+         device,
+         sigmas=sigmas,
+     )
+ 
+     # 5. Prepare latent variables
+     num_channels_latents = pipe.transformer.config.in_channels
+     num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1
+     latents = pipe.prepare_latents(
+         batch_size * num_videos_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         num_latent_frames,
+         torch.float32,
+         device,
+         generator,
+         latents,
+     )
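+     # Illustrative arithmetic for num_latent_frames above (assuming the stock
+     # HunyuanVideo VAE temporal scale factor of 4): num_frames=49 gives
+     # (49 - 1) // 4 + 1 = 13 latent frames.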
+ 
+     # 6. Prepare guidance condition
+     guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+ 
+     # 7. Denoising loop
+     num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
+     pipe._num_timesteps = len(timesteps)
+ 
+     with pipe.progress_bar(total=num_inference_steps) as progress_bar:
+         for i, t in enumerate(timesteps):
+             if pipe.interrupt:
+                 continue
+             pipe._current_timestep = t
+             latent_model_input = latents.to(transformer_dtype)
+             timestep = t.expand(latents.shape[0]).to(latents.dtype)
+             # Conditioning latents are concatenated along the channel dimension,
+             # matching the doubled in_chans of the patched x_embedder above.
+             noise_pred = pipe.transformer(
+                 hidden_states=torch.cat([latent_model_input, image_latents], dim=1),
+                 timestep=timestep,
+                 encoder_hidden_states=prompt_embeds,
+                 encoder_attention_mask=prompt_attention_mask,
+                 pooled_projections=pooled_prompt_embeds,
+                 guidance=guidance,
+                 attention_kwargs=attention_kwargs,
+                 return_dict=False,
+             )[0]
+ 
+             # compute the previous noisy sample x_t -> x_t-1
+             latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ 
+             if callback_on_step_end is not None:
+                 callback_kwargs = {}
+                 for k in callback_on_step_end_tensor_inputs:
+                     callback_kwargs[k] = locals()[k]
+                 callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)
+                 latents = callback_outputs.pop("latents", latents)
+                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ 
+             # call the callback, if provided
+             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
+                 progress_bar.update()
+ 
+     pipe._current_timestep = None
+     if not output_type == "latent":
+         latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
+         video = pipe.vae.decode(latents, return_dict=False)[0]
+         video = pipe.video_processor.postprocess_video(video, output_type=output_type)
+     else:
+         video = latents
+ 
+     # Offload all models
+     pipe.maybe_free_model_hooks()
+ 
+     if not return_dict:
+         return (video,)
+     return HunyuanVideoPipelineOutput(frames=video)
+ 
+ 
+ def main():
+     # Example frame links, rendered below the interface via the `article` slot
+     example_links = (
+         "- https://i-bacon.bunkr.ru/11b45aa7-630b-4189-996f-a6b37a697786.png\n"
+         "- https://i-bacon.bunkr.ru/2382224f-120e-482d-a75d-f1a1bf13038c.png"
+     )
+ 
+     # Define the interface inputs
+     inputs = [
+         gr.Textbox(label="Prompt", value="a woman"),
+         gr.Image(label="Frame 1", type="pil"),
+         gr.Image(label="Frame 2", type="pil"),
+         gr.Dropdown(
+             label="Resolution",
+             choices=["720x1280", "544x960", "1280x720", "960x544", "720x720"],
+             value="544x960",
+         ),
+         # gr.Textbox(label="Frame 1 URL", value="https://i-bacon.bunkr.ru/11b45aa7-630b-4189-996f-a6b37a697786.png"),
+         # gr.Textbox(label="Frame 2 URL", value="https://i-bacon.bunkr.ru/2382224f-120e-482d-a75d-f1a1bf13038c.png"),
+         gr.Slider(minimum=0.1, maximum=20, step=0.1, label="Guidance Scale", value=6.0),
+         gr.Slider(minimum=1, maximum=129, step=1, label="Number of Frames", value=49),
+         gr.Slider(minimum=1, maximum=100, step=1, label="Number of Inference Steps", value=30),
+         gr.Slider(minimum=1, maximum=60, step=1, label="FPS", value=16),
+     ]
+ 
+     # Define the interface outputs
+     outputs = [
+         gr.Video(label="Generated Video"),
+     ]
+ 
+     # Create the Gradio interface
+     iface = gr.Interface(
+         fn=generate_video,
+         inputs=inputs,
+         outputs=outputs,
+         title="Hunyuan Video Generator",
+         description="Generate videos using the HunyuanVideo model with a prompt and two frames as conditions.",
+         article=example_links,
+     )
+ 
+     # Launch the Gradio app
+     iface.launch(show_error=True)
+ 
+ 
+ if __name__ == "__main__":
+     main()