import spaces
import gc
import gradio as gr
import numpy as np
import os
from pathlib import Path
from diffusers import GGUFQuantizationConfig, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download
import torch
from PIL import Image
# Global configuration: inference-only, deterministic cuDNN, clean CUDA cache
gc.collect()
torch.cuda.empty_cache()
torch.set_grad_enabled(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Download the base model repo locally; its transformer is swapped below
# for a GGUF-quantized version.
model_id = "hunyuanvideo-community/HunyuanVideo"
base_path = f"/home/user/app/{model_id}"
os.makedirs(base_path, exist_ok=True)
snapshot_download(repo_id=model_id, local_dir=base_path)
# Load the transformer from a GGUF-quantized single-file checkpoint (Q4_0)
ckp_path = Path(base_path)
gguf_filename = "hunyuan-video-t2v-720p-Q4_0.gguf"
transformer_path = f"https://huggingface.co/city96/HunyuanVideo-gguf/blob/main/{gguf_filename}"
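# from_single_file accepts a Hub blob URL like the one above and resolves it
# to the underlying GGUF file.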
transformer = HunyuanVideoTransformer3DModel.from_single_file(
transformer_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
).to('cuda')
# Initialize pipeline
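# Running the pipeline in float16 around a bfloat16-compute GGUF transformer
# follows the diffusers GGUF loading example for HunyuanVideo.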
pipe = HunyuanVideoPipeline.from_pretrained(
ckp_path,
transformer=transformer,
torch_dtype=torch.float16
).to("cuda")
# Configure VAE
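# Tiling and slicing decode the latents in chunks, trading speed for a much
# lower peak VRAM during VAE decode.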
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
pipe.vae.eval()
# Available LoRAs as (weight filename, display name) pairs
LORA_CHOICES = [
("stripe_v2.safetensors", "Stripe Style"),
("Top_Off.safetensors", "Top Off Effect"),
("huanyan_helper.safetensors", "Hunyuan Helper"),
("huanyan_helper_alpha.safetensors", "Hunyuan Alpha"),
("hunyuan-t-solo-v1.0.safetensors", "Solo Animation")
]
# Load every LoRA up front under a normalized adapter name; set_adapters()
# activates the requested subset at generation time. The token is only
# needed if the Sergidev/TTV4ME repo is gated or private.
for weight_name, display_name in LORA_CHOICES:
pipe.load_lora_weights(
"Sergidev/TTV4ME",
weight_name=weight_name,
adapter_name=display_name.replace(" ", "_").lower(),
token=os.environ.get("HF_TOKEN")
)
# Memory cleanup
gc.collect()
torch.cuda.empty_cache()
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
@spaces.GPU(duration=300)
def generate(
    prompt,
    image_input,
    height,
    width,
    num_frames,
    num_inference_steps,
    seed_value,
    fps,
    selected_loras,
    # Gradio flattens the slider list wired into `inputs`, so the per-LoRA
    # weights arrive as individual positional values collected here.
    *lora_weights,
    progress=gr.Progress(track_tqdm=True),
):
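    """Generate a video from a text prompt or an optional input image,
    applying the currently selected LoRA adapters."""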
# Validate image resolution
if image_input is not None:
img = Image.open(image_input)
if img.size != (width, height):
raise gr.Error(f"Image resolution {img.size} must match video resolution ({width}x{height})")
    # Activate only the selected adapters. The flattened lora_weights tuple
    # carries one slider value per LORA_CHOICES entry, in choice order.
    active_adapters, weights = [], []
    for i, (_, display_name) in enumerate(LORA_CHOICES):
        if display_name in selected_loras:
            active_adapters.append(display_name.replace(" ", "_").lower())
            weights.append(float(lora_weights[i]))
    if active_adapters:  # skip when nothing is selected
        pipe.set_adapters(active_adapters, weights)
with torch.cuda.device(0):
if seed_value == -1:
seed_value = torch.randint(0, MAX_SEED, (1,)).item()
generator = torch.Generator('cuda').manual_seed(seed_value)
        with torch.autocast('cuda', dtype=torch.bfloat16), torch.inference_mode():
            # Use the image input if provided, otherwise the text prompt.
            # NOTE: the stock HunyuanVideoPipeline is text-to-video; the
            # image= call below assumes an image-to-video pipeline variant
            # that accepts an `image` argument.
            if image_input is not None:
output = pipe(
image=Image.open(image_input).convert("RGB"),
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
generator=generator,
).frames[0]
else:
output = pipe(
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=num_inference_steps,
generator=generator,
).frames[0]
output_path = "output.mp4"
export_to_video(output, output_path, fps=fps)
torch.cuda.empty_cache()
gc.collect()
return output_path
def apply_preset(preset_name, *current_values):
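    """Return [height, width, num_frames, num_inference_steps, fps] for a
    named preset, or the current values unchanged."""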
if preset_name == "Higher Resolution":
return [608, 448, 24, 29, 12]
elif preset_name == "More Frames":
return [512, 320, 42, 27, 14]
return current_values
css = """
#col-container {
margin: 0 auto;
max-width: 850px;
}
.dark-theme {
background-color: #1f1f1f;
color: #ffffff;
}
.container {
margin: 0 auto;
padding: 20px;
border-radius: 10px;
background-color: #2d2d2d;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.title {
text-align: center;
margin-bottom: 1em;
color: #ffffff;
}
.description {
text-align: center;
margin-bottom: 2em;
color: #cccccc;
font-size: 0.95em;
line-height: 1.5;
}
.prompt-container {
background-color: #363636;
padding: 15px;
border-radius: 8px;
margin-bottom: 1em;
width: 100%;
}
.prompt-textbox {
min-height: 80px !important;
}
.preset-buttons {
display: flex;
gap: 10px;
justify-content: center;
margin-bottom: 1em;
}
.support-text {
text-align: center;
margin-top: 1em;
color: #cccccc;
font-size: 0.9em;
}
a {
color: #00a7e1;
text-decoration: none;
}
a:hover {
text-decoration: underline;
}
.lora-sliders {
margin-top: 15px;
border-top: 1px solid #444;
padding-top: 15px;
}
"""
with gr.Blocks(css=css, theme="dark") as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# 🎬 Hunyuan Studio", elem_classes=["title"])
gr.Markdown(
"""Generate videos from text or images using multiple LoRA adapters.
Requires matching resolution between input image and output settings.""",
elem_classes=["description"]
)
with gr.Column(elem_classes=["prompt-container"]):
prompt = gr.Textbox(
label="Prompt",
placeholder="Enter text prompt or upload image below",
show_label=False,
elem_classes=["prompt-textbox"],
lines=3
)
image_input = gr.Image(type="filepath", label="Upload Image (Optional)")
with gr.Row():
run_button = gr.Button("🎨 Generate", variant="primary", size="lg")
with gr.Row(elem_classes=["preset-buttons"]):
preset_high_res = gr.Button("📺 Higher Resolution Preset")
preset_more_frames = gr.Button("🎞️ More Frames Preset")
with gr.Row():
result = gr.Video(label="Generated Video")
with gr.Accordion("⚙️ Advanced Settings", open=False):
seed = gr.Slider(
label="Seed (-1 for random)",
minimum=-1,
maximum=MAX_SEED,
step=1,
value=-1,
)
with gr.Row():
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=16,
value=608,
)
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=16,
value=448,
)
with gr.Row():
num_frames = gr.Slider(
label="Number of frames",
                    minimum=1,
                    maximum=257,
step=1,
value=24,
)
num_inference_steps = gr.Slider(
label="Inference steps",
minimum=1,
maximum=50,
step=1,
value=29,
)
fps = gr.Slider(
label="Frames per second",
minimum=1,
maximum=60,
step=1,
value=12,
)
with gr.Column(elem_classes=["lora-sliders"]):
gr.Markdown("### LoRA Adapters")
lora_checkboxes = gr.CheckboxGroup(
label="Select LoRAs",
choices=[display for (_, display) in LORA_CHOICES],
value=["Stripe Style", "Top Off Effect"]
)
lora_weight_sliders = []
for _, display_name in LORA_CHOICES:
lora_weight_sliders.append(
gr.Slider(
label=f"{display_name} Weight",
minimum=0.0,
maximum=1.0,
                            value=0.9 if "Stripe" in display_name else 0.8,
                            # Start visible for the LoRAs checked by default.
                            visible=display_name in ["Stripe Style", "Top Off Effect"],
)
)
# Event handling
run_button.click(
fn=generate,
        inputs=[prompt, image_input, height, width, num_frames,
                num_inference_steps, seed, fps, lora_checkboxes,
                # Unpack: Gradio requires a flat list of input components.
                *lora_weight_sliders],
outputs=[result],
)
# Preset button handlers
preset_high_res.click(
fn=lambda: apply_preset("Higher Resolution"),
outputs=[height, width, num_frames, num_inference_steps, fps]
)
preset_more_frames.click(
fn=lambda: apply_preset("More Frames"),
outputs=[height, width, num_frames, num_inference_steps, fps]
)
# Show/hide LORA weight sliders based on checkbox selection
def toggle_lora_sliders(selected_loras):
updates = []
for lora in LORA_CHOICES:
updates.append(gr.update(visible=lora[1] in selected_loras))
return updates
lora_checkboxes.change(
fn=toggle_lora_sliders,
inputs=lora_checkboxes,
outputs=lora_weight_sliders
)
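
# Launch the Gradio app
demo.launch()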