# Couple-Product / app.py
# Author: Zeph27
# Commit message: Add image files using Git LFS
# Commit: 61f0b81
import gradio as gr
import asyncio
import fal_client
from dotenv import load_dotenv
import os
from pathlib import Path
import time
import json
# Pull variables from a local .env file (if present) into the process environment.
load_dotenv()
# The key is stored as FAL_API_KEY, but this app exposes it as FAL_KEY
# (presumably the name fal_client reads). Only copy it when it is set:
# os.environ values must be str, so assigning None would raise TypeError
# at import time and crash the app before the UI starts.
_fal_api_key = os.getenv("FAL_API_KEY")
if _fal_api_key:
    os.environ["FAL_KEY"] = _fal_api_key
# Literal dropdown option (boy hair style/color) meaning "not specified";
# it must not be copied verbatim into the prompt text.
_UNSPECIFIED_OPTION = "None"


def _hair_description(length: str, style: str, color: str) -> str:
    """Return the text substituted for ``{hair_feature}`` in a prompt.

    "Bald" short-circuits to ``"bald,"``. Otherwise empty or "None" selections
    are dropped, so the result is e.g. "Short Black hair," rather than
    "Short None Black hair,".
    """
    if length == "Bald":
        return "bald,"
    words = [w for w in (length, style, color) if w and w != _UNSPECIFIED_OPTION]
    words.append("hair,")
    return " ".join(words)


def _output_image_urls(result, node_id: str) -> list:
    """Return the image URLs emitted by workflow output node *node_id*.

    A missing "outputs", node, or "images" entry yields an empty list instead
    of raising, so one absent output does not hide the other.
    """
    outputs = (result or {}).get("outputs") or {}
    node = outputs.get(node_id) or {}
    return [img["url"] for img in node.get("images") or []]


async def generate_paris_images(product_name: str, image1_path: str, image2_path: str, woman_prompt: str, man_prompt: str, girl_name: str, girl_hair_length: str, girl_hair_style: str, girl_hair_color: str, boy_name: str, boy_hair_length: str, boy_hair_style: str, boy_hair_color: str, batch_size: int, progress=gr.Progress()):
    """Generate couple portraits with the fal ``comfy/LVE/paris-couple`` workflow.

    Uploads the two user photos plus the pose/mask templates from
    ``template/``, fills the ``{hair_feature}`` placeholder in each prompt,
    submits the job, and waits for its result.

    Args:
        product_name: Selected product; not used by this function.
        image1_path: Local path of the woman's photo.
        image2_path: Local path of the man's photo.
        woman_prompt: Prompt template for the woman; ``{hair_feature}`` is replaced.
        man_prompt: Prompt template for the man; ``{hair_feature}`` is replaced.
        girl_name: Passed through to the workflow unchanged.
        girl_hair_length: Hair length selection for the woman.
        girl_hair_style: Hair style selection for the woman.
        girl_hair_color: Hair color selection for the woman.
        boy_name: Passed through to the workflow unchanged.
        boy_hair_length: Hair length selection for the man ("Bald" -> "bald,").
        boy_hair_style: Hair style selection for the man ("None" is omitted).
        boy_hair_color: Hair color selection for the man ("None" is omitted).
        batch_size: Number of images requested from the workflow.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(urls_from_node_215, urls_from_node_818, "Processing time: N.NN seconds")``.

    Raises:
        gr.Error: If either photo is missing (e.g. the image input was cleared).
    """
    # A cleared gr.Image yields None; str(None) would otherwise be uploaded
    # as a nonexistent file called "None" and fail with an obscure error.
    if not image1_path or not image2_path:
        raise gr.Error("Please upload both the woman image and the man image.")

    start_time = time.perf_counter()
    print("Progress: 5% - Starting Paris image generation...")
    progress(0.05, desc="Starting Paris image generation...")

    # Upload the user photos and the fixed pose/mask templates concurrently.
    (
        image1_url,
        image2_url,
        man_pose_img,
        woman_pose_img,
        woman_clip_mask,
        man_clip_mask,
    ) = await asyncio.gather(
        fal_client.upload_file_async(str(image1_path)),
        fal_client.upload_file_async(str(image2_path)),
        fal_client.upload_file_async("template/man_pose.png"),
        fal_client.upload_file_async("template/woman_pose.png"),
        fal_client.upload_file_async("template/woman_clip_mask.png"),
        fal_client.upload_file_async("template/man_clip_mask.png"),
    )
    print("Progress: 40% - Uploaded all images")
    progress(0.4, desc="Uploaded all images")

    # Replace {hair_feature} placeholders with user hair descriptions.
    woman_hair_desc = _hair_description(girl_hair_length, girl_hair_style, girl_hair_color)
    print(f"Final woman hair description: {woman_hair_desc}")
    man_hair_desc = _hair_description(boy_hair_length, boy_hair_style, boy_hair_color)
    print(f"Final man hair description: {man_hair_desc}")

    woman_prompt = woman_prompt.replace("{hair_feature}", woman_hair_desc)
    man_prompt = man_prompt.replace("{hair_feature}", man_hair_desc)
    print(f"Final woman prompt: {woman_prompt}")
    print(f"Final man prompt: {man_prompt}")

    handler = await fal_client.submit_async(
        "comfy/LVE/paris-couple",
        arguments={
            # loadimage_N slots of the workflow; the order here is what the
            # workflow receives, so keep it unchanged.
            "loadimage_1": image1_url,
            "loadimage_2": image2_url,
            "loadimage_3": woman_pose_img,
            "loadimage_4": woman_clip_mask,
            "loadimage_5": man_clip_mask,
            "loadimage_6": man_pose_img,
            "woman_prompt": woman_prompt,
            "man_prompt": man_prompt,
            "girl_name": girl_name,
            "boy_name": boy_name,
            # gr.Slider may deliver a float (e.g. 4.0); send an integer count.
            "batch_size": int(batch_size),
        },
    )
    print("Progress: 60% - Processing images...")
    progress(0.6, desc="Processing images...")
    result = await handler.get()
    print(result)

    processing_time = time.perf_counter() - start_time
    print(f"Progress: 100% - Generation completed in {processing_time:.2f} seconds")
    progress(1.0, desc=f"Generation completed in {processing_time:.2f} seconds")

    # Nodes "215" and "818" are the two workflow outputs shown in the UI.
    image_215 = _output_image_urls(result, "215")
    image_818 = _output_image_urls(result, "818")
    print(f"Image 215: {image_215}")
    print(f"Image 818: {image_818}")

    return (
        image_215,
        image_818,
        f"Processing time: {processing_time:.2f} seconds",
    )
def change_product_preview(product_name, prompts_path="prompt.json"):
    """Return the preview image path and default prompts for *product_name*.

    Args:
        product_name: Title of the selected product (e.g. "Winter").
        prompts_path: JSON file containing a list of objects with "title",
            "woman" and "man" keys. Defaults to "prompt.json", resolved
            relative to the current working directory.

    Returns:
        ``(f"thumbnail/{product_name}.png", woman_prompt, man_prompt)`` for the
        first entry whose "title" matches, otherwise ``(None, "", "")``.
    """
    # Read explicitly as UTF-8. The platform default encoding (e.g. cp1252 on
    # Windows) would mis-decode non-ASCII prompt text.
    with open(prompts_path, "r", encoding="utf-8") as f:
        prompts = json.load(f)

    # First entry whose title matches; entries without a "title" are skipped
    # instead of raising KeyError.
    prompt_data = next(
        (item for item in prompts if item.get("title") == product_name), None
    )
    if prompt_data is None:
        return None, "", ""
    return (
        f"thumbnail/{product_name}.png",
        prompt_data.get("woman", ""),
        prompt_data.get("man", ""),
    )
# Gradio UI: product selector, two input photos, per-person prompt/name/hair
# controls, a Generate button, and two galleries plus a timing box for results.
with gr.Blocks() as demo:
    # Product selector and its thumbnail preview.
    with gr.Row():
        product_name = gr.Dropdown(label="Product Name", choices=["Winter", "Classy", "Night Out", "Romantic"], value="Winter")
        product_preview = gr.Image(label="Product Preview", type="filepath", value="thumbnail/Winter.png", height=500, width=500)
    # Input photos, delivered to the handler as local file paths.
    # NOTE(review): default sample files are assumed to exist in the working
    # directory — TODO confirm.
    with gr.Row():
        image1_input = gr.Image(label="Upload Woman Image", type="filepath", value="user3-f.jpg")
        image2_input = gr.Image(label="Upload Man Image", type="filepath", value="user3-m.jpg")
    with gr.Row():
        # Woman's controls. The initial prompt is the "Winter" text (the
        # dropdown default); changing the product reloads it from prompt.json.
        # Presumably this mirrors the "Winter" entry there — keep in sync.
        # The {hair_feature} placeholder is filled in by generate_paris_images.
        with gr.Column():
            woman_prompt = gr.Textbox(
                label="Woman Prompt",
                value="Close-up, portrait photo, a woman, {hair_feature} wearing a cream-colored wool coat, chunky knit scarf, and matching earmuffs, standing on the same snow-dusted cobblestone street, illuminated Eiffel Tower glowing golden in the background, snowflakes sparkling in the warm streetlight glow, looking at camera with gentle smile."
            )
            girl_name = gr.Textbox(
                label="Girl Name",
                value="julie delpy"
            )
            girl_hair_length = gr.Dropdown(label="Girl Hair Length", choices=["Short", "Medium", "Long"], value="Long")
            girl_hair_style = gr.Dropdown(label="Girl Hair Style", choices=["Straight", "Wavy", "Curly"], value="Straight")
            girl_hair_color = gr.Dropdown(label="Girl Hair Color", choices=["Blonde", "Brown", "Black", "Brunette", "Redhead", "Bronde"], value="Bronde")
        # Man's controls; same structure as the woman's column. "Bald" is
        # special-cased by generate_paris_images. The "None" options are
        # plain strings, not Python None.
        with gr.Column():
            man_prompt = gr.Textbox(
                label="Man Prompt",
                value="Close-up, portrait photo, a man, {hair_feature} wearing a dark navy wool peacoat, cashmere scarf, and leather gloves, standing on a snow-dusted cobblestone street, illuminated Eiffel Tower in the background glowing golden against the night sky, gentle snowflakes catching the warm glow of vintage streetlamps, looking at camera with confident expression."
            )
            boy_name = gr.Textbox(
                label="Boy Name",
                value="ethan hawke"
            )
            boy_hair_length = gr.Dropdown(label="Boy Hair Length", choices=["Short", "Medium", "Long", "Bald"], value="Short")
            boy_hair_style = gr.Dropdown(label="Boy Hair Style", choices=["None", "Undercut", "Mullet", "French Crop", "Slicked Back", "Fade", "Buzz Cut"], value="Undercut")
            boy_hair_color = gr.Dropdown(label="Boy Hair Color", choices=["None", "Blonde", "Brown", "Black", "Brunette", "Redhead"], value="Black")
    # Number of images requested per generation.
    batch_size = gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Batch Size")
    generate_btn = gr.Button("Generate")
    # Result galleries: the first receives the URLs from workflow node "215",
    # the second those from node "818" (generate_paris_images' return order).
    with gr.Row():
        image_output = gr.Gallery(label="Generated Image Raw")
        image_output_processed = gr.Gallery(label="Generated Image Final")
    time_output = gr.Textbox(label="Processing Time")
    # The inputs list is passed positionally, so its order must match
    # generate_paris_images' parameter order.
    generate_btn.click(
        fn=generate_paris_images,
        inputs=[product_name, image1_input, image2_input, woman_prompt, man_prompt, girl_name, girl_hair_length, girl_hair_style, girl_hair_color, boy_name, boy_hair_length, boy_hair_style, boy_hair_color, batch_size],
        outputs=[image_output, image_output_processed, time_output]
    )
    # Selecting a product swaps the thumbnail and reloads both prompt templates.
    product_name.change(
        fn=change_product_preview,
        inputs=[product_name],
        outputs=[product_preview, woman_prompt, man_prompt]
    )
def _launch_ui() -> None:
    """Print a startup notice, then launch the ``demo`` Gradio app."""
    print("Starting Gradio interface...")
    demo.launch()


# Serve the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    _launch_ui()