try-on-kolor / app.py
rishh76's picture
Update app.py
3ada0c7 verified
raw
history blame
2.98 kB
import gradio as gr
import numpy as np
import random
import torch
from diffusers import StableDiffusionPipeline
# Load the Stable Diffusion model for text-based garment generation.
# NOTE: the original "runwayml/stable-diffusion-v1-5" repository was removed
# from the Hugging Face Hub; this is the maintained mirror of the same weights.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Prefer the GPU, but fall back to CPU instead of crashing at startup when no
# CUDA device is present. float16 is only used on CUDA because half precision
# is poorly supported (slow or unsupported ops) on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Inclusive upper bound for randomly drawn seeds and for the seed slider.
MAX_SEED = 999999
def generate_garment(person_img, cloth_description, seed, randomize_seed):
    """Generate a garment from a text prompt and overlay it on the person image.

    Args:
        person_img: The uploaded person image (a NumPy array from the
            ``gr.Image(type="numpy")`` input), or ``None`` if nothing was
            uploaded.
        cloth_description: Text prompt describing the garment.
        seed: Seed for the diffusion sampler; ignored when ``randomize_seed``
            is true. May arrive as a float from the Gradio slider.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        A ``(result_image, seed_used, status)`` tuple matching the Gradio
        outputs. On missing/blank input returns ``(None, None, "Invalid input")``.
    """
    prompt = cloth_description.strip() if cloth_description else ""
    if person_img is None or not prompt:
        return None, None, "Invalid input"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders can deliver floats; generator seeds must be integers.
    seed = int(seed)

    # A dedicated generator scopes the seed to this pipeline call instead of
    # mutating torch's global RNG state, which is shared across requests.
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    garment_img = pipe(prompt, generator=generator).images[0]

    # Combine the generated garment with the person's image.
    result_img = combine_images(person_img, garment_img)
    return result_img, seed, "Success"
def _resize_nearest(image, height, width):
    """Resize an (H, W[, C]) array to (height, width) by nearest-neighbour sampling."""
    rows = np.arange(height) * image.shape[0] // height
    cols = np.arange(width) * image.shape[1] // width
    return image[rows[:, None], cols]


def _to_channels_last(image):
    """Return *image* as an (H, W, C) array, expanding 2-D grayscale to 3 channels."""
    image = np.asarray(image)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    return image


def combine_images(person_img, garment_img):
    """Overlay the garment image on the person image.

    This is a naive overlay: garment pixels replace person pixels wherever the
    garment is non-transparent. Proper try-on needs segmentation/masking.

    Args:
        person_img: Person image as an array (or anything ``np.asarray``
            accepts). Grayscale is expanded to RGB; an alpha channel is dropped.
        garment_img: A PIL image (e.g. Stable Diffusion output), resized with
            PIL to the person's size, or an (H, W[, C]) NumPy array, resized
            with nearest-neighbour sampling when its size differs.

    Returns:
        An (H, W, 3) NumPy array the size of ``person_img``. If the garment has
        an alpha channel, pixels with alpha > 0 come from the garment and the
        rest from the person. A garment without alpha (such as Stable
        Diffusion's RGB output) counts as fully opaque and covers the image.
    """
    person = _to_channels_last(person_img)[..., :3]
    height, width = person.shape[:2]

    if isinstance(garment_img, np.ndarray):
        garment = _to_channels_last(garment_img)
        if garment.shape[:2] != (height, width):
            garment = _resize_nearest(garment, height, width)
    else:
        # PIL's resize takes (width, height).
        garment = _to_channels_last(garment_img.resize((width, height)))

    if garment.shape[-1] == 4:
        opaque = garment[..., 3:] > 0
    else:
        # No alpha channel: the original `[:, :, 3:]` slice was empty here and
        # made np.where fail to broadcast; treat every pixel as opaque instead.
        opaque = np.ones((height, width, 1), dtype=bool)
    return np.where(opaque, garment[..., :3], person)
# Page styles: the three step columns are centred and capped at 430px wide;
# #col-showcase is a wider (1100px) centred container.
# NOTE(review): no component in this file uses elem_id="col-showcase".
css = """
#col-left, #col-mid, #col-right {
    margin: 0 auto;
    max-width: 430px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""
# Three-step UI: upload a person image, describe a garment, then generate.
with gr.Blocks(css=css) as Tryon:
    gr.HTML("<h1>Virtual Try-On with Text-based Garment Generation</h1>")
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML("<h3>Step 1: Upload a person image ⬇️</h3>")
            # `source=` was removed in Gradio 4 (replaced by `sources=`) and
            # raises a TypeError there; omitting it keeps uploads working on
            # both 3.x (default "upload") and 4.x.
            person_img = gr.Image(label="Person Image", type="numpy")
        with gr.Column(elem_id="col-mid"):
            gr.HTML("<h3>Step 2: Describe the garment ⬇️</h3>")
            cloth_description = gr.Textbox(label="Garment Description", placeholder="e.g., red dress with floral pattern")
        with gr.Column(elem_id="col-right"):
            gr.HTML("<h3>Step 3: Generate Try-On Image ⬇️</h3>")
            result_img = gr.Image(label="Result", show_share_button=False)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            seed_used = gr.Number(label="Seed Used", interactive=False)
            result_info = gr.Text(label="Status", interactive=False)
            generate_button = gr.Button(value="Run")

    # Event listeners must be registered inside the Blocks context.
    generate_button.click(
        fn=generate_garment,
        inputs=[person_img, cloth_description, seed, randomize_seed],
        outputs=[result_img, seed_used, result_info],
    )

if __name__ == "__main__":
    # Queue requests so slow diffusion calls don't time out or contend for
    # the GPU; launching only when run as a script avoids import side effects.
    Tryon.queue().launch()