Maximofn committed
Commit 67c95e3 · Parent(s): f179ff2

feat(src): :rocket: Update code with diffusers info

Files changed (2):
  1. app.py +40 -41
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,31 +1,37 @@
+import torch
+from diffusers import BitsAndBytesConfig, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+
 import os
 import time
-from pathlib import Path
 from datetime import datetime
 import gradio as gr
-import random
-import os
 
-from hyvideo.utils.file_utils import save_videos_grid
 from hyvideo.config import parse_args
-from hyvideo.inference import HunyuanVideoSampler
-from hyvideo.constants import NEGATIVE_PROMPT
 
-def initialize_model(model_path):
-    args = parse_args()
-    # models_root_path = Path(model_path)
-    # if not models_root_path.exists():
-    #     raise ValueError(f"`models_root` not exists: {models_root_path}")
-
-    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(
-        model_path,
-        args=args,
-        device_map="auto"
+
+def initialize_model(model):
+    quant_config = BitsAndBytesConfig(load_in_8bit=True)
+
+    transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
+        model,
+        subfolder="transformer",
+        quantization_config=quant_config,
+        torch_dtype=torch.bfloat16,
+        device_map="balanced",
+    )
+
+    # Load the pipeline
+    pipeline = HunyuanVideoPipeline.from_pretrained(
+        model,
+        transformer=transformer_8bit,
+        torch_dtype=torch.float16,
+        device_map="balanced",
     )
-    return hunyuan_video_sampler
+
+    return pipeline
 
 def generate_video(
-    model,
+    pipeline,
     prompt,
     resolution,
     video_length,
@@ -38,38 +44,32 @@ def generate_video(
     seed = None if seed == -1 else seed
     width, height = resolution.split("x")
     width, height = int(width), int(height)
-    negative_prompt = "" # not applicable in the inference
-
-    outputs = model.predict(
+
+    # Generate the video using the pipeline
+    video = pipeline(
         prompt=prompt,
         height=height,
-        width=width,
-        video_length=video_length,
-        seed=seed,
-        negative_prompt=negative_prompt,
-        infer_steps=num_inference_steps,
+        width=width,
+        num_frames=video_length,
+        num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
-        num_videos_per_prompt=1,
-        flow_shift=flow_shift,
-        batch_size=1,
-        embedded_guidance_scale=embedded_guidance_scale
-    )
-
-    samples = outputs['samples']
-    sample = samples[0].unsqueeze(0)
+    ).frames[0]
 
+    # Save the video
     save_path = os.path.join(os.getcwd(), "gradio_outputs")
     os.makedirs(save_path, exist_ok=True)
 
     time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
-    video_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][0]}_{outputs['prompts'][0][:100].replace('/','')}.mp4"
-    save_videos_grid(sample, video_path, fps=24)
+    video_path = f"{save_path}/{time_flag}_seed{seed}_{prompt[:100].replace('/','')}.mp4"
+
+    from diffusers.utils import export_to_video
+    export_to_video(video, video_path, fps=24)
     print(f'Sample saved to: {video_path}')
 
     return video_path
 
-def create_demo(model_path, save_path):
-    model = initialize_model(model_path)
+def create_demo(model, save_path):
+    pipeline = initialize_model(model)
 
     with gr.Blocks() as demo:
         gr.Markdown("# Hunyuan Video Generation")
@@ -119,7 +119,7 @@ def create_demo(model_path, save_path):
         output = gr.Video(label="Generated Video")
 
         generate_btn.click(
-            fn=lambda *inputs: generate_video(model, *inputs),
+            fn=lambda *inputs: generate_video(pipeline, *inputs),
             inputs=[
                 prompt,
                 resolution,
@@ -141,7 +141,6 @@ if __name__ == "__main__":
     server_name = os.getenv("SERVER_NAME", "0.0.0.0")
     server_port = int(os.getenv("SERVER_PORT", "8081"))
     args = parse_args()
-    print(args)
-    model = "tencent/HunyuanVideo"
+    model = "hunyuanvideo-community/HunyuanVideo"  # Updated model path
     demo = create_demo(model, args.save_path)
     demo.launch(server_name=server_name, server_port=server_port)
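
For reference, the diffusers flow this commit wires into Gradio can be exercised standalone. The sketch below is a minimal reconstruction under stated assumptions: it uses only calls that appear in the diff (BitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline, export_to_video), while the prompt, resolution, frame count, step count, and output filename are illustrative values, not taken from the commit.

# Minimal standalone sketch of the new app.py flow; assumes a CUDA GPU,
# diffusers installed from git, and bitsandbytes available.
import torch
from diffusers import BitsAndBytesConfig, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model = "hunyuanvideo-community/HunyuanVideo"

# Quantize only the transformer to 8-bit to reduce VRAM usage.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
    model,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

# The rest of the pipeline loads in fp16 around the quantized transformer.
pipeline = HunyuanVideoPipeline.from_pretrained(
    model,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

# Illustrative generation settings (not from the commit); HunyuanVideo
# expects dimensions divisible by 16 and a frame count of the form 4k+1.
video = pipeline(
    prompt="A cat walks on the grass, realistic style.",
    height=480,
    width=720,
    num_frames=61,
    num_inference_steps=30,
    guidance_scale=6.0,
).frames[0]

export_to_video(video, "sample.mp4", fps=24)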
requirements.txt CHANGED
@@ -2,7 +2,8 @@ torch==2.4.0
 torchvision==0.19.0
 torchaudio==2.4.0
 opencv-python==4.9.0.80
-diffusers==0.31.0
+# diffusers==0.31.0
+git+https://github.com/huggingface/diffusers
 transformers==4.46.3
 tokenizers==0.20.3
 accelerate==1.1.1
@@ -15,6 +16,7 @@ imageio==2.34.0
 imageio-ffmpeg==0.5.1
 safetensors==0.4.3
 gradio==5.0.0
+bitsandbytes
 # ninja
 # git+https://github.com/Dao-AILab/[email protected]
 # xfuser==0.4.0
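
Swapping the diffusers==0.31.0 pin for a git install is presumably needed because the HunyuanVideoPipeline and HunyuanVideoTransformer3DModel classes used in app.py are newer than that release, and bitsandbytes backs the 8-bit BitsAndBytesConfig. A quick post-install sanity check (a sketch, not part of the commit):

# Verifies the git build of diffusers exposes the HunyuanVideo classes
# app.py imports, and that bitsandbytes is importable.
import bitsandbytes  # noqa: F401  (required by BitsAndBytesConfig at load time)
import diffusers
from diffusers import BitsAndBytesConfig, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel  # noqa: F401

print(diffusers.__version__)  # expected: a dev version newer than 0.31.0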