import os
import time
from datetime import datetime

import gradio as gr
import spaces
import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)
from diffusers.utils import export_to_video
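

# Text-to-video demo for HunyuanVideo via diffusers: generate_video loads the
# model on demand, samples the requested frames, and writes an .mp4 for Gradio.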
@spaces.GPU  # request a GPU for this call when running on a ZeroGPU Space
def generate_video(
    prompt,
    resolution,
    video_length,
    seed,
    num_inference_steps,
    guidance_scale,
    flow_shift,
    embedded_guidance_scale,
):
    # Treat -1 as "random seed"; Gradio's Number component returns a float,
    # and torch.Generator.manual_seed requires an int
    seed = None if seed == -1 else int(seed)
    width, height = resolution.split("x")
    width, height = int(width), int(height)
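
    # Load the transformer on its own so its weights can be spread across the
    # available devices before the rest of the pipeline is assembled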
    model = "hunyuanvideo-community/HunyuanVideo"
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model,
        subfolder="transformer",
        device_map="balanced",
        torch_dtype=torch.bfloat16,
    )
    print(f"transformer device: {transformer.device}")
    # Load the pipeline around the preloaded transformer
    pipeline = HunyuanVideoPipeline.from_pretrained(
        model,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
        device_map="balanced",
    )
    print(f"pipeline device: {pipeline.device}")
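
    # Flow shift biases the flow-matching timestep schedule toward the
    # high-noise region; larger values are generally used at higher resolutions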
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        pipeline.scheduler.config, shift=flow_shift
    )
    # Decode the VAE in tiles to keep peak memory manageable at these resolutions
    pipeline.vae.enable_tiling()
    # No pipeline.to("cuda") needed: device_map="balanced" already placed the modules
    # Generate the video. The released HunyuanVideo checkpoint is CFG-distilled,
    # so diffusers' guidance_scale argument is the embedded guidance; the UI's
    # Embedded Guidance Scale maps to it, while the plain Guidance Scale slider
    # is accepted by the interface but unused here.
    generator = None if seed is None else torch.Generator("cpu").manual_seed(seed)
    video = pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=video_length,
        num_inference_steps=num_inference_steps,
        guidance_scale=embedded_guidance_scale,
        generator=generator,
    ).frames[0]
    # Save the video; export_to_video writes at 24 fps, HunyuanVideo's native rate
    save_path = os.path.join(os.getcwd(), "gradio_outputs")
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H-%M-%S")
    video_path = f"{save_path}/{time_flag}_seed{seed}_{prompt[:100].replace('/', '')}.mp4"
    export_to_video(video, video_path, fps=24)
    print(f"Sample saved to: {video_path}")
    return video_path


def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Hunyuan Video Generation")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A cat walks on the grass, realistic style.")
                with gr.Row():
                    resolution = gr.Dropdown(
                        choices=[
                            # 720p
                            ("1280x720 (16:9, 720p)", "1280x720"),
                            ("720x1280 (9:16, 720p)", "720x1280"),
                            ("1104x832 (4:3, 720p)", "1104x832"),
                            ("832x1104 (3:4, 720p)", "832x1104"),
                            ("960x960 (1:1, 720p)", "960x960"),
                            # 540p
                            ("960x544 (16:9, 540p)", "960x544"),
                            ("544x960 (9:16, 540p)", "544x960"),
                            ("832x624 (4:3, 540p)", "832x624"),
                            ("624x832 (3:4, 540p)", "624x832"),
                            ("720x720 (1:1, 540p)", "720x720"),
                        ],
                        value="1280x720",
                        label="Resolution",
                    )
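                    # HunyuanVideo's temporal VAE packs 4 frames per latent, so
                    # valid frame counts follow 4k + 1 (65 ≈ 2 s, 129 ≈ 5 s at 24 fps)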
                    video_length = gr.Dropdown(
                        label="Video Length",
                        choices=[
                            ("2s (65f)", 65),
                            ("5s (129f)", 129),
                        ],
                        value=129,
                    )
                num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Number of Inference Steps")
                show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
                with gr.Row(visible=False) as advanced_row:
                    with gr.Column():
                        seed = gr.Number(value=-1, label="Seed (-1 for random)")
                        guidance_scale = gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Guidance Scale")
                        flow_shift = gr.Slider(0.0, 10.0, value=7.0, step=0.1, label="Flow Shift")
                        embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale")
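                # Returning a Row with the new visibility updates advanced_row in place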
                show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
                generate_btn = gr.Button("Generate")
            with gr.Column():
                output = gr.Video(label="Generated Video")
        generate_btn.click(
            fn=generate_video,
            inputs=[
                prompt,
                resolution,
                video_length,
                seed,
                num_inference_steps,
                guidance_scale,
                flow_shift,
                embedded_guidance_scale,
            ],
            outputs=output,
        )
    return demo


if __name__ == "__main__":
    print("Starting Gradio server...")
    demo = create_demo()
    demo.launch()