File size: 4,570 Bytes
599abe6
 
 
 
 
 
 
c079f49
599abe6
 
c079f49
599abe6
c079f49
599abe6
c079f49
599abe6
 
c079f49
599abe6
 
c079f49
 
 
 
 
 
599abe6
 
 
c079f49
599abe6
 
 
 
 
 
c079f49
599abe6
 
 
 
 
 
 
 
c079f49
599abe6
c079f49
599abe6
 
 
c079f49
599abe6
c079f49
599abe6
 
c079f49
 
599abe6
 
 
 
 
 
 
 
 
 
c079f49
599abe6
c079f49
 
 
 
 
 
599abe6
 
c079f49
599abe6
 
 
c079f49
599abe6
 
c079f49
599abe6
c079f49
599abe6
 
 
 
c079f49
599abe6
 
 
c079f49
599abe6
 
 
 
 
 
 
c079f49
 
599abe6
c079f49
599abe6
c079f49
 
599abe6
 
 
c079f49
599abe6
c079f49
 
 
 
 
599abe6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import time
import random
import requests
from lumaai import LumaAI
import traceback

from lib.status_utils import load_messages, StatusTracker
from lib.image_utils import prepare_image
from lib.api_utils import get_camera_motions
from lib.ui_components import create_interface

def generate_video(api_key, prompt, camera_motion, loop_video, image=None, progress=gr.Progress()):
    """Generate a video with LumaAI and return the path of the downloaded file.

    The function creates a generation request, polls it once per second until
    it completes, fails, or exceeds the five-minute budget, then downloads the
    resulting video to ``output_video.mp4`` in the working directory.

    Args:
        api_key: LumaAI auth token; must be non-empty.
        prompt: Text prompt; must be non-empty.
        camera_motion: Text appended to the prompt. ``None``, an empty string
            and the literal string ``'None'`` all mean "no camera motion".
        loop_video: Value forwarded as the ``loop`` generation parameter.
        image: Optional input image. When given it is passed to
            ``prepare_image`` and the returned URL is used as keyframe
            ``frame0``.
        progress: Gradio progress callback used to report status to the UI.

    Returns:
        The path of the downloaded video file (``"output_video.mp4"``).

    Raises:
        gr.Error: On missing inputs, image preparation failure, failure to
            start the generation, a ``'failed'`` generation state, a polling
            timeout, a download failure, or any other unexpected error.
    """
    if not api_key or not prompt:
        raise gr.Error("Please enter your LumaAI API key and prompt")

    poll_timeout = 300  # seconds; matches the "5 minutes" in the error message
    poll_interval = 1   # seconds between status requests

    try:
        progress(0, desc="Initializing LumaAI...")
        client = LumaAI(auth_token=api_key)

        # Create status tracker with progress object
        status_tracker = StatusTracker(
            progress=lambda x: progress(x),
            status_box=None
        )

        # Only append the camera motion when one was actually selected, so the
        # prompt never ends in a stray space or the literal text "None".
        full_prompt = prompt
        if camera_motion and camera_motion != 'None':
            full_prompt = f"{prompt} {camera_motion}"

        # Prepare generation parameters
        generation_params = {
            "prompt": full_prompt,
            "loop": loop_video,
            "aspect_ratio": "1:1",  # Force square aspect ratio
        }

        # Handle image if provided
        if image is not None:
            try:
                progress(0.1, desc="Preparing image...")
                cdn_url = prepare_image(image, status_tracker)
            except Exception as e:
                print(f"Error while preparing image: {str(e)}")
                print(traceback.format_exc())
                raise gr.Error("Failed to process the input image") from e
            generation_params["keyframes"] = {
                "frame0": {
                    "type": "image",
                    "url": cdn_url,
                }
            }

        progress(0.2, desc="Starting generation...")
        try:
            generation = client.generations.create(**generation_params)
        except Exception as e:
            print(f"Error while starting generation: {str(e)}")
            print(traceback.format_exc())
            raise gr.Error("Failed to start video generation. Please check your API key.") from e

        # Load and shuffle status messages for variety
        status_messages = load_messages()
        random.shuffle(status_messages)
        message_index = 0
        last_message_time = time.time()

        # Poll for completion
        start_time = time.time()

        while True:
            # Only the network call is treated as retryable. The terminal
            # gr.Error raises below are deliberately outside this try so they
            # propagate instead of being swallowed and retried forever.
            try:
                generation_status = client.generations.get(generation.id)
            except Exception as e:
                print(f"Error during generation polling: {str(e)}")
                print(traceback.format_exc())
                # Enforce the deadline on the failure path too, otherwise a
                # permanently failing poll would never terminate.
                if time.time() - start_time > poll_timeout:
                    raise gr.Error("Generation timed out after 5 minutes") from e
                time.sleep(poll_interval)
                continue

            status = generation_status.state
            current_time = time.time()
            elapsed_time = current_time - start_time

            # Update status message at random intervals between 2-8 seconds.
            # Skipped when no messages were loaded (avoids modulo by zero).
            if status_messages and current_time - last_message_time >= random.uniform(2, 8):
                progress_val = min(0.2 + (elapsed_time / 60), 0.8)  # Adjusted for 1-minute expectation
                progress(progress_val, desc=status_messages[message_index % len(status_messages)])
                message_index += 1
                last_message_time = current_time

            if status == 'completed':
                progress(0.9, desc="Generation completed!")
                download_url = generation_status.assets.video
                break
            if status == 'failed':
                raise gr.Error("Video generation failed")
            if elapsed_time > poll_timeout:
                raise gr.Error("Generation timed out after 5 minutes")

            time.sleep(poll_interval)

        # Download the video
        progress(0.95, desc="Downloading video...")
        try:
            file_path = "output_video.mp4"
            # Stream to disk in chunks instead of buffering the whole video in
            # memory; the context manager releases the connection afterwards.
            with requests.get(download_url, stream=True, timeout=30) as response:
                response.raise_for_status()
                with open(file_path, 'wb') as file:
                    for chunk in response.iter_content(chunk_size=65536):
                        file.write(chunk)

            progress(1.0, desc="Video ready!")
            return file_path
        except Exception as e:
            print(f"Error while downloading video: {str(e)}")
            print(traceback.format_exc())
            raise gr.Error("Failed to download the generated video") from e

    except gr.Error:
        # Already user-facing; re-raise unchanged.
        raise
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        print(traceback.format_exc())
        raise gr.Error("An unexpected error occurred") from e

# Build the UI object by handing the generation callback to the project's
# create_interface helper (imported from lib.ui_components).
app = create_interface(generate_video)

# For Hugging Face Spaces, we want to specify a smaller queue size.
# NOTE(review): assumes `app` is a Gradio Blocks/Interface, where max_size
# caps how many events may wait in the queue — confirm against
# create_interface.
app.queue(max_size=5)

# Start the server only when this file is executed as a script; importing the
# module still creates and configures `app` but does not launch it.
if __name__ == "__main__":
    app.launch()