Spaces:
Running
Running
import json | |
from pathlib import Path | |
from video_config import MODEL_FRAME_RATES, calculate_frames | |
class WanVideoWorkflow:
    """Load a WanVideo image-to-video workflow and patch its node inputs.

    The workflow JSON is indexed by node id, and every node touched here is
    expected to be a dict carrying a non-empty ``inputs`` dict.  The config
    JSON must contain a ``models`` list; the entry whose ``id`` is
    ``"wanvideo_itv"`` supplies a ``nodes`` mapping from logical names
    (``prompt``, ``image``, ``lora``, ``dimensions``, ``videoCombine``) to
    node ids in the workflow.
    """

    def __init__(self, supabase_client, config_path="config.json", workflow_path="wani2v.json"):
        """Read the config and workflow files and resolve the node mapping.

        Args:
            supabase_client: Client object stored as ``self.supabase`` and
                used by :meth:`get_lora_path` to query ``training_jobs``.
            config_path: Path to the JSON config file.
            workflow_path: Path to the JSON workflow file.

        Raises:
            ValueError: If the workflow file is empty or whitespace-only, or
                if no ``"wanvideo_itv"`` model entry exists in the config.
            OSError: If either file cannot be read.
            json.JSONDecodeError: If either file is not valid JSON.
            KeyError: If the config lacks ``models`` or the model lacks
                ``nodes``.

        Any exception is printed and then re-raised unchanged.
        """
        try:
            workflow_path = Path(workflow_path)
            config_path = Path(config_path)

            # Debug output describing the workflow file before it is opened.
            print(f"Loading workflow from: {workflow_path.absolute()}")
            print(f"File exists: {workflow_path.exists()}")
            print(f"File size: {workflow_path.stat().st_size}")

            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

            with open(workflow_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"Raw content length: {len(content)}")

            # Treat whitespace-only files as empty too, so the caller gets
            # this explicit message instead of an opaque JSON decode error.
            if not content.strip():
                raise ValueError(f"Workflow file is empty: {workflow_path}")
            self.workflow = json.loads(content)

            self.supabase = supabase_client

            # Find the model entry that carries the node-id mapping.
            self.model_config = next(
                (model for model in self.config["models"] if model["id"] == "wanvideo_itv"),
                None,
            )
            if self.model_config is None:
                raise ValueError("WanVideo I2V model config not found")
            self.nodes = self.model_config["nodes"]
        except Exception as e:
            print(f"Error initializing WanVideoWorkflow: {str(e)}")
            raise

    def _node_inputs(self, node_key: str, label: str) -> dict:
        """Return the ``inputs`` dict of the workflow node mapped to *node_key*.

        Args:
            node_key: Logical node name looked up in ``self.nodes``.
            label: Human-readable node name used in the error message.

        Raises:
            KeyError: If *node_key* is not mapped, or the mapped id is not
                in the workflow.
            ValueError: If the node has no ``inputs`` or it is empty.
        """
        node = self.workflow[self.nodes[node_key]]
        inputs = node.get("inputs")
        if not inputs:
            raise ValueError(f"Invalid {label} node structure")
        return inputs

    async def get_lora_path(self, lora_id) -> str:
        """Get LoRA path from Supabase training_jobs table.

        Selects ``name, visible, config`` for the row whose ``id`` equals
        *lora_id* and builds the adapter path from its ``visible`` value.

        Returns:
            ``"{lora_id}/run/epoch{visible}/adapter_model.safetensors"``.

        Raises:
            ValueError: If no row matches, or the row's ``visible`` is None.
        """
        # NOTE(review): execute() is not awaited; if the client is
        # synchronous this call blocks the event loop -- confirm.
        response = self.supabase.table('training_jobs').select(
            'name, visible, config'
        ).eq('id', lora_id).execute()
        if not response.data:
            raise ValueError(f"LoRA with ID {lora_id} not found")

        lora_data = response.data[0]
        visible_epoch = lora_data['visible']
        if visible_epoch is None:
            raise ValueError(f"LoRA {lora_id} has no visible epoch")
        return f"{lora_id}/run/epoch{visible_epoch}/adapter_model.safetensors"

    def update_prompt(self, prompt) -> None:
        """Set ``positive_prompt`` on the prompt node to *prompt*."""
        self._node_inputs("prompt", "prompt")["positive_prompt"] = prompt

    def update_input_image(self, image_path) -> None:
        """Set ``image`` on the image node to the file name of *image_path*.

        Only the final path component is stored; directories are dropped.
        """
        self._node_inputs("image", "image")["image"] = Path(image_path).name

    async def update_lora(self, lora_config) -> None:
        """Point the LoRA node at the adapter for ``lora_config["id"]``.

        The node is validated before the Supabase lookup, so a malformed
        workflow fails without issuing a query.  ``strength`` is always set
        to 1.0.
        """
        inputs = self._node_inputs("lora", "LoRA")
        inputs["lora"] = await self.get_lora_path(lora_config["id"])
        inputs["strength"] = 1.0

    def update_length(self, duration) -> None:
        """Update video length (number of frames).

        Sets ``num_frames`` on the dimensions node to
        ``calculate_frames(duration, MODEL_FRAME_RATES["wanvideo"])``.
        """
        inputs = self._node_inputs("dimensions", "dimensions")
        frame_rate = MODEL_FRAME_RATES["wanvideo"]
        inputs["num_frames"] = calculate_frames(duration, frame_rate)

    def update_output_name(self, generation_id) -> None:
        """Set ``filename_prefix`` on the video-combine node to *generation_id*."""
        self._node_inputs("videoCombine", "video combine")["filename_prefix"] = generation_id

    def get_workflow(self):
        """Return the workflow dict held by this instance (not a copy)."""
        return self.workflow