File size: 4,251 Bytes
d948455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
from pathlib import Path
from video_config import MODEL_FRAME_RATES, calculate_frames

class WanVideoWorkflow:
    """Build a WanVideo image-to-video workflow from a JSON template.

    The template is loaded once in ``__init__`` and then mutated in place by
    the ``update_*`` methods; ``get_workflow`` returns that same object.
    Logical node names ("prompt", "image", "lora", "dimensions",
    "videoCombine") are resolved to workflow node ids through the ``nodes``
    mapping of the ``wanvideo_itv`` entry in the config's ``models`` list.
    """

    def __init__(self, supabase_client, config_path="config.json", workflow_path="wani2v.json"):
        """Load the config and workflow template and resolve node mappings.

        Args:
            supabase_client: Client used by ``get_lora_path`` to query the
                ``training_jobs`` table.
            config_path: Path to the JSON config containing a ``models`` list.
            workflow_path: Path to the JSON workflow template.

        Raises:
            ValueError: If the workflow file is empty or whitespace-only, is
                not valid JSON (``json.JSONDecodeError`` subclasses
                ``ValueError``), or no model with id ``wanvideo_itv`` exists.
            OSError: If either file cannot be opened.
        """
        try:
            workflow_path = Path(workflow_path)
            config_path = Path(config_path)

            # Debug output to help diagnose path problems.
            print(f"Loading workflow from: {workflow_path.absolute()}")
            exists = workflow_path.exists()
            print(f"File exists: {exists}")
            if exists:
                # stat() raises on a missing file; let open() below report
                # that case instead of failing inside a debug print.
                print(f"File size: {workflow_path.stat().st_size}")

            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

            with open(workflow_path, 'r', encoding='utf-8') as f:
                content = f.read()
            print(f"Raw content length: {len(content)}")
            # A whitespace-only file is as unusable as an empty one; reject
            # both with an explicit message instead of a JSON parse error.
            if not content.strip():
                raise ValueError(f"Workflow file is empty: {workflow_path}")
            self.workflow = json.loads(content)

            self.supabase = supabase_client

            # Locate the I2V model entry; .get tolerates entries without an id.
            self.model_config = next(
                (model for model in self.config["models"] if model.get("id") == "wanvideo_itv"),
                None
            )
            if not self.model_config:
                raise ValueError("WanVideo I2V model config not found")

            # Maps logical node names to workflow node ids.
            self.nodes = self.model_config["nodes"]

        except Exception as e:
            print(f"Error initializing WanVideoWorkflow: {str(e)}")
            raise

    def _node_inputs(self, node_key, label):
        """Return the ``inputs`` dict of the workflow node mapped to *node_key*.

        Args:
            node_key: Logical name looked up in ``self.nodes``.
            label: Human-readable node name used only in the error message.

        Raises:
            KeyError: If *node_key* is not mapped, or the mapped id is not a
                workflow node.
            ValueError: If the node has missing or empty ``inputs``.
        """
        node_id = self.nodes[node_key]
        try:
            node = self.workflow[node_id]
        except KeyError:
            # json.loads always produces string object keys, so an integer id
            # in the config would otherwise never match.
            node = self.workflow[str(node_id)]
        inputs = node.get("inputs")
        if not inputs:
            raise ValueError(f"Invalid {label} node structure")
        return inputs

    async def get_lora_path(self, lora_id):
        """Get LoRA path from Supabase training_jobs table.

        Returns ``"<lora_id>/run/epoch<visible>/adapter_model.safetensors"``,
        where ``visible`` is the row's ``visible`` column.

        NOTE(review): ``.execute()`` is called without ``await``, so the
        query runs synchronously and may block the event loop.

        Raises:
            ValueError: If no row matches *lora_id* or its ``visible`` is None.
        """
        response = self.supabase.table('training_jobs').select(
            'name, visible, config'
        ).eq('id', lora_id).execute()

        if not response.data:
            raise ValueError(f"LoRA with ID {lora_id} not found")

        lora_data = response.data[0]
        visible_epoch = lora_data['visible']

        if visible_epoch is None:
            raise ValueError(f"LoRA {lora_id} has no visible epoch")

        return f"{lora_id}/run/epoch{visible_epoch}/adapter_model.safetensors"

    def update_prompt(self, prompt):
        """Set the prompt node's ``positive_prompt`` input."""
        self._node_inputs("prompt", "prompt")["positive_prompt"] = prompt

    def update_input_image(self, image_path):
        """Set the image node's ``image`` input to the file name of *image_path*.

        Only the final path component is stored; directories are dropped.
        """
        self._node_inputs("image", "image")["image"] = Path(image_path).name

    async def update_lora(self, lora_config, strength=1.0):
        """Point the LoRA node at the adapter for ``lora_config["id"]``.

        Args:
            lora_config: Mapping with at least an ``"id"`` key, used as the
                ``training_jobs`` row id.
            strength: Value written to the node's ``strength`` input. It
                defaults to 1.0, the value previously hard-coded.

        Raises:
            ValueError: If the LoRA node has no inputs, or the LoRA cannot be
                resolved by ``get_lora_path``.
        """
        # Validate the node before hitting the database.
        inputs = self._node_inputs("lora", "LoRA")
        inputs["lora"] = await self.get_lora_path(lora_config["id"])
        inputs["strength"] = strength

    def update_length(self, duration):
        """Set the dimensions node's ``num_frames`` from *duration*.

        The frame count is ``calculate_frames(duration, frame_rate)`` using
        ``MODEL_FRAME_RATES["wanvideo"]``. The unit of *duration* is whatever
        ``calculate_frames`` expects (presumably seconds — confirm).
        """
        inputs = self._node_inputs("dimensions", "dimensions")
        frame_rate = MODEL_FRAME_RATES["wanvideo"]
        inputs["num_frames"] = calculate_frames(duration, frame_rate)

    def update_output_name(self, generation_id):
        """Set the video combine node's ``filename_prefix`` to *generation_id*."""
        inputs = self._node_inputs("videoCombine", "video combine")
        # Coerce to str so non-string ids (e.g. uuid.UUID) keep the workflow
        # JSON-serializable.
        inputs["filename_prefix"] = str(generation_id)

    def get_workflow(self):
        """Return the (mutated) workflow object itself, not a copy."""
        return self.workflow