Spaces:
Running
Running
""" | |
Preview tab for Video Model Studio UI | |
""" | |
import gradio as gr | |
import logging | |
from pathlib import Path | |
from typing import Dict, Any, List, Optional, Tuple | |
import time | |
from vms.utils import BaseTab | |
from vms.config import ( | |
MODEL_TYPES, DEFAULT_PROMPT_PREFIX | |
) | |
logger = logging.getLogger(__name__) | |
class PreviewTab(BaseTab): | |
"""Preview tab for testing trained models""" | |
def __init__(self, app_state): | |
super().__init__(app_state) | |
self.id = "preview_tab" | |
self.title = "5️⃣ Preview" | |
def create(self, parent=None) -> gr.TabItem: | |
"""Create the Preview tab UI components""" | |
with gr.TabItem(self.title, id=self.id) as tab: | |
with gr.Row(): | |
gr.Markdown("## Preview your model") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
self.components["prompt"] = gr.Textbox( | |
label="Prompt", | |
placeholder="Enter your prompt here...", | |
lines=3 | |
) | |
self.components["negative_prompt"] = gr.Textbox( | |
label="Negative Prompt", | |
placeholder="Enter negative prompt here...", | |
lines=3, | |
value="worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background" | |
) | |
self.components["prompt_prefix"] = gr.Textbox( | |
label="Global Prompt Prefix", | |
placeholder="Prefix to add to all prompts", | |
value=DEFAULT_PROMPT_PREFIX | |
) | |
with gr.Row(): | |
# Get the currently selected model type from training tab if possible | |
default_model = self.get_default_model_type() | |
# Make model_type read-only (disabled), as it must match what was trained | |
self.components["model_type"] = gr.Dropdown( | |
choices=list(MODEL_TYPES.keys()), | |
label="Model Type (from training)", | |
value=default_model, | |
interactive=False | |
) | |
# Add model variant selection based on model type | |
self.components["model_variant"] = gr.Dropdown( | |
label="Model Variant", | |
choices=self.get_variant_choices(default_model), | |
value=self.get_default_variant(default_model) | |
) | |
# Add image input for image-to-video models | |
self.components["conditioning_image"] = gr.Image( | |
label="Conditioning Image (for Image-to-Video models)", | |
type="filepath", | |
visible=False | |
) | |
with gr.Row(): | |
self.components["resolution_preset"] = gr.Dropdown( | |
choices=["480p", "720p"], | |
label="Resolution Preset", | |
value="480p" | |
) | |
with gr.Row(): | |
self.components["width"] = gr.Number( | |
label="Width", | |
value=832, | |
precision=0 | |
) | |
self.components["height"] = gr.Number( | |
label="Height", | |
value=480, | |
precision=0 | |
) | |
with gr.Row(): | |
self.components["num_frames"] = gr.Slider( | |
label="Number of Frames", | |
minimum=1, | |
maximum=257, | |
step=8, | |
value=49 | |
) | |
self.components["fps"] = gr.Slider( | |
label="FPS", | |
minimum=1, | |
maximum=60, | |
step=1, | |
value=16 | |
) | |
with gr.Row(): | |
self.components["guidance_scale"] = gr.Slider( | |
label="Guidance Scale", | |
minimum=1.0, | |
maximum=10.0, | |
step=0.1, | |
value=5.0 | |
) | |
self.components["flow_shift"] = gr.Slider( | |
label="Flow Shift", | |
minimum=0.0, | |
maximum=10.0, | |
step=0.1, | |
value=3.0 | |
) | |
with gr.Row(): | |
self.components["lora_weight"] = gr.Slider( | |
label="LoRA Weight", | |
minimum=0.0, | |
maximum=1.0, | |
step=0.01, | |
value=0.7 | |
) | |
self.components["inference_steps"] = gr.Slider( | |
label="Inference Steps", | |
minimum=1, | |
maximum=100, | |
step=1, | |
value=30 | |
) | |
self.components["enable_cpu_offload"] = gr.Checkbox( | |
label="Enable Model CPU Offload (for low-VRAM GPUs)", | |
value=True | |
) | |
self.components["generate_btn"] = gr.Button( | |
"Generate Video", | |
variant="primary" | |
) | |
with gr.Column(scale=3): | |
self.components["preview_video"] = gr.Video( | |
label="Generated Video", | |
interactive=False | |
) | |
self.components["status"] = gr.Textbox( | |
label="Status", | |
interactive=False | |
) | |
with gr.Accordion("Log", open=True): | |
self.components["log"] = gr.TextArea( | |
label="Generation Log", | |
interactive=False, | |
lines=15 | |
) | |
return tab | |
def get_variant_choices(self, model_type: str) -> List[str]: | |
"""Get model variant choices based on model type""" | |
# Convert UI display name to internal name | |
internal_type = MODEL_TYPES.get(model_type) | |
if not internal_type: | |
return [] | |
# Get variants from preview service | |
variants = self.app.previewing.get_model_variants(internal_type) | |
if not variants: | |
return [] | |
# Format choices with display name and description | |
choices = [] | |
for model_id, info in variants.items(): | |
choices.append(f"{model_id} - {info.get('name', '')}") | |
return choices | |
def get_default_variant(self, model_type: str) -> str: | |
"""Get default model variant for the model type""" | |
choices = self.get_variant_choices(model_type) | |
if choices: | |
return choices[0] | |
return "" | |
def get_default_model_type(self) -> str: | |
"""Get the currently selected model type from training tab""" | |
try: | |
# Try to get the model type from UI state | |
ui_state = self.app.training.load_ui_state() | |
model_type = ui_state.get("model_type") | |
# Make sure it's a valid model type | |
if model_type in MODEL_TYPES: | |
return model_type | |
# If we couldn't get a valid model type, try to get it from the training tab directly | |
if hasattr(self.app, 'tabs') and 'train_tab' in self.app.tabs: | |
train_tab = self.app.tabs['train_tab'] | |
if hasattr(train_tab, 'components') and 'model_type' in train_tab.components: | |
train_model_type = train_tab.components['model_type'].value | |
if train_model_type in MODEL_TYPES: | |
return train_model_type | |
# Fallback to first model type | |
return list(MODEL_TYPES.keys())[0] | |
except Exception as e: | |
logger.warning(f"Failed to get default model type: {e}") | |
return list(MODEL_TYPES.keys())[0] | |
def extract_model_id(self, variant_choice: str) -> str: | |
"""Extract model ID from variant choice string""" | |
if " - " in variant_choice: | |
return variant_choice.split(" - ")[0].strip() | |
return variant_choice | |
def get_variant_type(self, model_type: str, model_variant: str) -> str: | |
"""Get the variant type (text-to-video or image-to-video)""" | |
# Convert UI display name to internal name | |
internal_type = MODEL_TYPES.get(model_type) | |
if not internal_type: | |
return "text-to-video" | |
# Extract model_id from variant choice | |
model_id = self.extract_model_id(model_variant) | |
# Get variants from preview service | |
variants = self.app.previewing.get_model_variants(internal_type) | |
variant_info = variants.get(model_id, {}) | |
# Return the variant type or default to text-to-video | |
return variant_info.get("type", "text-to-video") | |
def connect_events(self) -> None: | |
"""Connect event handlers to UI components""" | |
# Update resolution when preset changes | |
self.components["resolution_preset"].change( | |
fn=self.update_resolution, | |
inputs=[self.components["resolution_preset"]], | |
outputs=[ | |
self.components["width"], | |
self.components["height"], | |
self.components["flow_shift"] | |
] | |
) | |
# Update model_variant choices when model_type changes or tab is selected | |
if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None: | |
self.app.tabs_component.select( | |
fn=self.sync_model_type_and_variants, | |
inputs=[], | |
outputs=[ | |
self.components["model_type"], | |
self.components["model_variant"] | |
] | |
) | |
# Update variant-specific UI elements when variant changes | |
self.components["model_variant"].change( | |
fn=self.update_variant_ui, | |
inputs=[ | |
self.components["model_type"], | |
self.components["model_variant"] | |
], | |
outputs=[ | |
self.components["conditioning_image"] | |
] | |
) | |
# Load preview UI state when the tab is selected | |
if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None: | |
self.app.tabs_component.select( | |
fn=self.load_preview_state, | |
inputs=[], | |
outputs=[ | |
self.components["prompt"], | |
self.components["negative_prompt"], | |
self.components["prompt_prefix"], | |
self.components["width"], | |
self.components["height"], | |
self.components["num_frames"], | |
self.components["fps"], | |
self.components["guidance_scale"], | |
self.components["flow_shift"], | |
self.components["lora_weight"], | |
self.components["inference_steps"], | |
self.components["enable_cpu_offload"], | |
self.components["model_variant"] | |
] | |
) | |
# Save preview UI state when values change | |
for component_name in [ | |
"prompt", "negative_prompt", "prompt_prefix", "model_variant", "resolution_preset", | |
"width", "height", "num_frames", "fps", "guidance_scale", "flow_shift", | |
"lora_weight", "inference_steps", "enable_cpu_offload" | |
]: | |
if component_name in self.components: | |
self.components[component_name].change( | |
fn=self.save_preview_state_value, | |
inputs=[self.components[component_name]], | |
outputs=[] | |
) | |
# Generate button click | |
self.components["generate_btn"].click( | |
fn=self.generate_video, | |
inputs=[ | |
self.components["model_type"], | |
self.components["model_variant"], | |
self.components["prompt"], | |
self.components["negative_prompt"], | |
self.components["prompt_prefix"], | |
self.components["width"], | |
self.components["height"], | |
self.components["num_frames"], | |
self.components["guidance_scale"], | |
self.components["flow_shift"], | |
self.components["lora_weight"], | |
self.components["inference_steps"], | |
self.components["enable_cpu_offload"], | |
self.components["fps"], | |
self.components["conditioning_image"] | |
], | |
outputs=[ | |
self.components["preview_video"], | |
self.components["status"], | |
self.components["log"] | |
] | |
) | |
def update_variant_ui(self, model_type: str, model_variant: str) -> Dict[str, Any]: | |
"""Update UI based on the selected model variant""" | |
variant_type = self.get_variant_type(model_type, model_variant) | |
# Show conditioning image input only for image-to-video models | |
show_conditioning_image = variant_type == "image-to-video" | |
return { | |
self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image) | |
} | |
def sync_model_type_and_variants(self) -> Tuple[str, str]: | |
"""Sync model type with training tab when preview tab is selected and update variant choices""" | |
model_type = self.get_default_model_type() | |
model_variant = self.get_default_variant(model_type) | |
return model_type, model_variant | |
def update_resolution(self, preset: str) -> Tuple[int, int, float]: | |
"""Update resolution and flow shift based on preset""" | |
if preset == "480p": | |
return 832, 480, 3.0 | |
elif preset == "720p": | |
return 1280, 720, 5.0 | |
else: | |
return 832, 480, 3.0 | |
def load_preview_state(self) -> Tuple: | |
"""Load saved preview UI state""" | |
# Try to get the saved state | |
try: | |
state = self.app.training.load_ui_state() | |
preview_state = state.get("preview", {}) | |
# Get model type (can't be changed in UI) | |
model_type = self.get_default_model_type() | |
# If model_variant not in choices for current model_type, use default | |
model_variant = preview_state.get("model_variant", "") | |
variant_choices = self.get_variant_choices(model_type) | |
if model_variant not in variant_choices and variant_choices: | |
model_variant = variant_choices[0] | |
return ( | |
preview_state.get("prompt", ""), | |
preview_state.get("negative_prompt", "worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background"), | |
preview_state.get("prompt_prefix", DEFAULT_PROMPT_PREFIX), | |
preview_state.get("width", 832), | |
preview_state.get("height", 480), | |
preview_state.get("num_frames", 49), | |
preview_state.get("fps", 16), | |
preview_state.get("guidance_scale", 5.0), | |
preview_state.get("flow_shift", 3.0), | |
preview_state.get("lora_weight", 0.7), | |
preview_state.get("inference_steps", 30), | |
preview_state.get("enable_cpu_offload", True), | |
model_variant | |
) | |
except Exception as e: | |
logger.error(f"Error loading preview state: {e}") | |
# Return defaults if loading fails | |
return ( | |
"", | |
"worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background", | |
DEFAULT_PROMPT_PREFIX, | |
832, 480, 49, 16, 5.0, 3.0, 0.7, 30, True, | |
self.get_default_variant(self.get_default_model_type()) | |
) | |
def save_preview_state_value(self, value: Any) -> None: | |
"""Save an individual preview state value""" | |
try: | |
# Get the component name from the event context | |
import inspect | |
frame = inspect.currentframe() | |
frame = inspect.getouterframes(frame)[1] | |
event_context = frame.frame.f_locals | |
component = event_context.get('component') | |
if component is None: | |
return | |
# Find the component name | |
component_name = None | |
for name, comp in self.components.items(): | |
if comp == component: | |
component_name = name | |
break | |
if component_name is None: | |
return | |
# Load current state | |
state = self.app.training.load_ui_state() | |
if "preview" not in state: | |
state["preview"] = {} | |
# Update the value | |
state["preview"][component_name] = value | |
# Save state | |
self.app.training.save_ui_state(state) | |
except Exception as e: | |
logger.error(f"Error saving preview state: {e}") | |
def generate_video( | |
self, | |
model_type: str, | |
model_variant: str, | |
prompt: str, | |
negative_prompt: str, | |
prompt_prefix: str, | |
width: int, | |
height: int, | |
num_frames: int, | |
guidance_scale: float, | |
flow_shift: float, | |
lora_weight: float, | |
inference_steps: int, | |
enable_cpu_offload: bool, | |
fps: int, | |
conditioning_image: Optional[str] = None | |
) -> Tuple[Optional[str], str, str]: | |
"""Handler for generate button click, delegates to preview service""" | |
# Save all the parameters to preview state before generating | |
try: | |
state = self.app.training.load_ui_state() | |
if "preview" not in state: | |
state["preview"] = {} | |
# Extract model ID from variant choice | |
model_variant_id = self.extract_model_id(model_variant) | |
# Update all values | |
preview_state = { | |
"prompt": prompt, | |
"negative_prompt": negative_prompt, | |
"prompt_prefix": prompt_prefix, | |
"model_type": model_type, | |
"model_variant": model_variant, | |
"width": width, | |
"height": height, | |
"num_frames": num_frames, | |
"fps": fps, | |
"guidance_scale": guidance_scale, | |
"flow_shift": flow_shift, | |
"lora_weight": lora_weight, | |
"inference_steps": inference_steps, | |
"enable_cpu_offload": enable_cpu_offload | |
} | |
state["preview"] = preview_state | |
self.app.training.save_ui_state(state) | |
except Exception as e: | |
logger.error(f"Error saving preview state before generation: {e}") | |
# Clear the log display at the start to make room for new logs | |
# Yield and sleep briefly to allow UI update | |
yield None, "Starting generation...", "" | |
time.sleep(0.1) | |
# Extract model ID from variant choice string | |
model_variant_id = self.extract_model_id(model_variant) | |
# Use streaming updates to provide real-time feedback during generation | |
def generate_with_updates(): | |
# Initial UI update | |
yield None, "Initializing generation...", "Starting video generation process..." | |
# Start actual generation | |
result = self.app.previewing.generate_video( | |
model_type=model_type, | |
model_variant=model_variant_id, | |
prompt=prompt, | |
negative_prompt=negative_prompt, | |
prompt_prefix=prompt_prefix, | |
width=width, | |
height=height, | |
num_frames=num_frames, | |
guidance_scale=guidance_scale, | |
flow_shift=flow_shift, | |
lora_weight=lora_weight, | |
inference_steps=inference_steps, | |
enable_cpu_offload=enable_cpu_offload, | |
fps=fps, | |
conditioning_image=conditioning_image | |
) | |
# Return final result | |
return result | |
# Return the generator for streaming updates | |
return generate_with_updates() |