File size: 3,962 Bytes
291e480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import asyncio
import fal_client
from dotenv import load_dotenv
import os
from pathlib import Path
import time

# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# This project stores its fal key as FAL_API_KEY, while fal_client looks for
# FAL_KEY, so mirror it across. Only copy when the key is actually set:
# assigning None into os.environ raises TypeError, which would crash the app
# at import time with an unhelpful message.
_fal_api_key = os.getenv("FAL_API_KEY")
if _fal_api_key:
    os.environ["FAL_KEY"] = _fal_api_key
elif not os.getenv("FAL_KEY"):
    print("Warning: neither FAL_API_KEY nor FAL_KEY is set; fal requests will likely fail to authenticate.")

def _first_output_image_urls(result: dict) -> list:
    """Return the image URLs of the first workflow output node that has images.

    ``result["outputs"]`` is read as a mapping of output-node id -> node
    output, where a node output may carry an ``"images"`` list of dicts with a
    ``"url"`` key (the shape the calling code relies on). Nodes without images
    are skipped instead of raising ``KeyError``. Returns ``[]`` when no node
    produced images.
    """
    outputs = result.get("outputs") or {}
    for node_output in outputs.values():
        images = node_output.get("images") if isinstance(node_output, dict) else None
        if images:
            return [img["url"] for img in images]
    return []


async def generate_paris_images(image1_path: str, image2_path: str, woman_prompt: str, man_prompt: str, batch_size: int, progress=gr.Progress()):
    """Run the ``comfy/LVE/paris-couple`` fal workflow on two user photos.

    Uploads the two user images plus four bundled assets (a woman reference
    image, a man reference image and two masks) to fal in parallel, submits
    the workflow, and waits for its result.

    Args:
        image1_path: Local file path of the uploaded woman image.
        image2_path: Local file path of the uploaded man image.
        woman_prompt: Text sent to the workflow as ``woman_prompt``.
        man_prompt: Text sent to the workflow as ``man_prompt``.
        batch_size: Number of images requested; coerced to ``int`` because a
            UI slider may deliver it as a float.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(image_urls, timing_text)``: URLs for the gallery (empty when the
        result contains no images) and a human-readable processing time.

    Raises:
        gr.Error: If either user image is missing.
    """
    if not image1_path or not image2_path:
        # gr.Image(type="filepath") yields None after the user clears it;
        # without this check str(None) would be uploaded as a file named "None".
        raise gr.Error("Please upload both the woman image and the man image.")

    start_time = time.perf_counter()  # monotonic clock for elapsed time
    print("Progress: 5% - Starting Paris image generation...")
    progress(0.05, desc="Starting Paris image generation...")

    # Upload all images concurrently; asyncio.gather returns results in
    # argument order, so the unpacking below lines up with the calls.
    image1_url, image2_url, woman_img, man_img, mask1_img, mask2_img = await asyncio.gather(
        fal_client.upload_file_async(str(image1_path)),
        fal_client.upload_file_async(str(image2_path)),
        fal_client.upload_file_async("test_images/woman.png"),
        fal_client.upload_file_async("test_images/man.png"),
        fal_client.upload_file_async("test_images/clipspace-mask-4736783.png"),
        fal_client.upload_file_async("test_images/clipspace-mask-4722992.png"),
    )

    print("Progress: 40% - Uploaded all images")
    progress(0.4, desc="Uploaded all images")

    # NOTE: loadimage_N slots follow the workflow's node numbering, not the
    # upload order above (e.g. the man reference goes to slot 6).
    handler = await fal_client.submit_async(
        "comfy/LVE/paris-couple",
        arguments={
            "loadimage_1": image1_url,
            "loadimage_2": image2_url,
            "loadimage_3": woman_img,
            "loadimage_4": mask1_img,
            "loadimage_5": mask2_img,
            "loadimage_6": man_img,
            "woman_prompt": woman_prompt,
            "man_prompt": man_prompt,
            "batch_size": int(batch_size),
        },
    )

    print("Progress: 60% - Processing images...")
    progress(0.6, desc="Processing images...")

    result = await handler.get()
    print(result)

    processing_time = time.perf_counter() - start_time
    print(f"Progress: 100% - Generation completed in {processing_time:.2f} seconds")
    progress(1.0, desc=f"Generation completed in {processing_time:.2f} seconds")

    return (
        _first_output_image_urls(result),
        f"Processing time: {processing_time:.2f} seconds",
    )

# Default prompts pre-filled in the UI; users may edit them before generating.
_DEFAULT_WOMAN_PROMPT = (
    "Close-up, portrait photo, a woman, Paris nighttime romance scene, "
    "wearing an elegant black dress with a shawl, "
    "standing beneath the same canopy of twinkling lights along the Champs-Élysées, "
    "the Eiffel Tower glowing bright in the distance, "
    "soft mist rising from the street, looking at the camera."
)
_DEFAULT_MAN_PROMPT = (
    "Close-up, portrait photo, a man, Paris nighttime romance scene, "
    "wearing a tailored suit with a crisp white shirt, "
    "standing beneath a canopy of twinkling lights along the Champs-Élysées, "
    "the Eiffel Tower glowing bright in the distance, "
    "soft mist rising from the street, looking at the camera."
)

# UI layout. Components render in creation order, so statement order below
# defines the page layout.
with gr.Blocks() as demo:
    # Row 1: the two source photos, delivered to the handler as file paths.
    with gr.Row():
        image1_input = gr.Image(
            label="Upload Woman Image",
            type="filepath",
            value="test_images/user3-f.jpg",
        )
        image2_input = gr.Image(
            label="Upload Man Image",
            type="filepath",
            value="test_images/user3.jpg",
        )

    # Row 2: one editable prompt per person.
    with gr.Row():
        woman_prompt = gr.Textbox(label="Woman Prompt", value=_DEFAULT_WOMAN_PROMPT)
        man_prompt = gr.Textbox(label="Man Prompt", value=_DEFAULT_MAN_PROMPT)

    batch_size = gr.Slider(
        minimum=1,
        maximum=8,
        value=4,
        step=1,
        label="Batch Size",
    )

    generate_btn = gr.Button("Generate")
    image_output = gr.Gallery(label="Generated Image")
    time_output = gr.Textbox(label="Processing Time")

    # The input order must match generate_paris_images' positional parameters;
    # its (urls, timing_text) return feeds the gallery and the textbox.
    _handler_inputs = [image1_input, image2_input, woman_prompt, man_prompt, batch_size]
    _handler_outputs = [image_output, time_output]
    generate_btn.click(
        fn=generate_paris_images,
        inputs=_handler_inputs,
        outputs=_handler_outputs,
    )

# Script entry point: start the local Gradio server.
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch()