# remade-effects / workflow_handler.py
# Author: alex-remade
# Last commit d948455: "working with progress bar"
import json
from pathlib import Path
from video_config import MODEL_FRAME_RATES, calculate_frames
class WanVideoWorkflow:
    """Load a WanVideo I2V workflow JSON and patch the inputs of its nodes.

    The workflow is indexed by node ids; the config file supplies, for the
    model entry whose ``"id"`` is ``"wanvideo_itv"``, a ``"nodes"`` mapping
    from logical names ("prompt", "image", "lora", "dimensions",
    "videoCombine") to those node ids. The ``update_*`` methods mutate
    ``self.workflow`` in place; ``get_workflow`` returns it.
    """

    def __init__(self, supabase_client, config_path="config.json", workflow_path="wani2v.json"):
        """Read the config and workflow files and resolve the node mapping.

        Args:
            supabase_client: Client used by ``get_lora_path`` to query the
                ``training_jobs`` table.
            config_path: Path to the JSON config holding a ``"models"`` list.
            workflow_path: Path to the workflow JSON file.

        Raises:
            FileNotFoundError: If the workflow or config file does not exist.
            ValueError: If the workflow file is empty or whitespace-only, if
                either file is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError``), or if no ``"wanvideo_itv"`` model entry is
                present in the config.
        """
        try:
            workflow_path = Path(workflow_path)
            config_path = Path(config_path)
            print(f"Loading workflow from: {workflow_path.absolute()}")
            exists = workflow_path.exists()
            print(f"File exists: {exists}")
            if not exists:
                # Fail with an explicit message instead of letting stat() raise.
                raise FileNotFoundError(f"Workflow file not found: {workflow_path}")
            print(f"File size: {workflow_path.stat().st_size}")

            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)
            with open(workflow_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"Raw content length: {len(content)}")
            # A whitespace-only file would otherwise surface as an opaque
            # JSONDecodeError ("Expecting value") instead of this message.
            if not content.strip():
                raise ValueError(f"Workflow file is empty: {workflow_path}")
            self.workflow = json.loads(content)
            self.supabase = supabase_client

            # Find this model's entry; a missing "models" list or an entry
            # without "id" falls through to the ValueError below rather than
            # crashing with a KeyError.
            self.model_config = next(
                (
                    model
                    for model in self.config.get("models", [])
                    if model.get("id") == "wanvideo_itv"
                ),
                None,
            )
            if not self.model_config:
                raise ValueError("WanVideo I2V model config not found")
            self.nodes = self.model_config["nodes"]
        except Exception as e:
            # Log at the construction boundary, then propagate unchanged.
            print(f"Error initializing WanVideoWorkflow: {str(e)}")
            raise

    def _node_inputs(self, node_key, label):
        """Return the ``inputs`` dict of the workflow node mapped to ``node_key``.

        Args:
            node_key: Logical node name, looked up in ``self.nodes``.
            label: Human-readable node name used in the error message.

        Raises:
            KeyError: If ``node_key`` is not in ``self.nodes`` or the mapped
                node id is not in the workflow.
            ValueError: If the node has no (or an empty) ``inputs`` entry.
        """
        node = self.workflow[self.nodes[node_key]]
        inputs = node.get("inputs")
        if not inputs:
            raise ValueError(f"Invalid {label} node structure")
        return inputs

    async def get_lora_path(self, lora_id):
        """Get LoRA path from Supabase training_jobs table.

        Returns:
            ``"{lora_id}/run/epoch{visible}/adapter_model.safetensors"``, where
            ``visible`` is the row's ``visible`` column.

        Raises:
            ValueError: If no row matches ``lora_id`` or its ``visible``
                value is ``None``.
        """
        # NOTE(review): execute() is called synchronously (not awaited), so it
        # blocks the event loop while the query runs; consider
        # asyncio.to_thread if the client is not async-aware.
        response = self.supabase.table('training_jobs').select(
            'name, visible, config'
        ).eq('id', lora_id).execute()
        if not response.data:
            raise ValueError(f"LoRA with ID {lora_id} not found")
        visible_epoch = response.data[0]['visible']
        if visible_epoch is None:
            raise ValueError(f"LoRA {lora_id} has no visible epoch")
        return f"{lora_id}/run/epoch{visible_epoch}/adapter_model.safetensors"

    def update_prompt(self, prompt):
        """Set the ``positive_prompt`` input of the prompt node."""
        self._node_inputs("prompt", "prompt")["positive_prompt"] = prompt

    def update_input_image(self, image_path):
        """Set the image node's ``image`` input to the file name of ``image_path``.

        Only the final path component is stored; directories are dropped.
        """
        self._node_inputs("image", "image")["image"] = Path(image_path).name

    async def update_lora(self, lora_config, strength=1.0):
        """Point the LoRA node at the LoRA identified by ``lora_config["id"]``.

        Args:
            lora_config: Mapping that must contain an ``"id"`` key.
            strength: Value written to the node's ``strength`` input
                (defaults to 1.0, the previously hard-coded value).

        Raises:
            ValueError: If the LoRA node is malformed, or if
                ``get_lora_path`` rejects the id.
        """
        # Validate the node before making the remote lookup.
        inputs = self._node_inputs("lora", "LoRA")
        lora_path = await self.get_lora_path(lora_config["id"])
        inputs["lora"] = lora_path
        inputs["strength"] = strength

    def update_length(self, duration):
        """Set the dimensions node's ``num_frames`` from ``duration``.

        The frame count comes from ``calculate_frames`` at the ``"wanvideo"``
        rate in ``MODEL_FRAME_RATES``; ``duration`` is presumably in seconds —
        confirm against ``video_config.calculate_frames``.
        """
        inputs = self._node_inputs("dimensions", "dimensions")
        frame_rate = MODEL_FRAME_RATES["wanvideo"]
        inputs["num_frames"] = calculate_frames(duration, frame_rate)

    def update_output_name(self, generation_id):
        """Set the video combine node's ``filename_prefix`` to ``generation_id``."""
        inputs = self._node_inputs("videoCombine", "video combine")
        # Coerce to str so non-string ids (e.g. UUID, int) still yield a
        # string filename prefix in the workflow.
        inputs["filename_prefix"] = str(generation_id)

    def get_workflow(self):
        """Return the complete workflow (the live dict, not a copy)."""
        return self.workflow