# Hugging Face Spaces status: Runtime error
import gradio as gr | |
import numpy as np | |
import random | |
import torch | |
from diffusers import StableDiffusionPipeline | |
# Load the Stable Diffusion model used to generate a garment image from text.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo appears to have been
# removed from the Hugging Face Hub; if loading fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is believed to host the same
# weights — confirm before switching.
model_id = "runwayml/stable-diffusion-v1-5"
# float16 weights are only practical on a CUDA device. Fall back to float32 on
# CPU so the app can still start (slowly) on hardware without a GPU instead of
# crashing at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)
# Inclusive upper bound for random seeds and for the seed slider in the UI.
MAX_SEED = 999999
def generate_garment(person_img, cloth_description, seed, randomize_seed):
    """Generate a garment from a text prompt and overlay it on a person image.

    Args:
        person_img: Person photo from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing was uploaded.
        cloth_description: Text prompt describing the garment.
        seed: Seed from the UI slider. It may arrive as a float from Gradio,
            so it is cast to ``int`` before use.
        randomize_seed: When true, ignore ``seed`` and draw a fresh one from
            ``[0, MAX_SEED]``.

    Returns:
        ``(result_image, seed_used, status)``. On missing or blank input the
        result is ``(None, None, "Invalid input")``.
    """
    if person_img is None or not cloth_description or not cloth_description.strip():
        return None, None, "Invalid input"
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # A dedicated generator ties this run to ``seed`` without mutating torch's
    # global RNG state (which torch.manual_seed would do).
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    garment_img = pipe(cloth_description, generator=generator).images[0]
    # Combine the generated garment with the person's image.
    result_img = combine_images(person_img, garment_img)
    return result_img, seed, "Success"
def combine_images(person_img, garment_img):
    """Overlay ``garment_img`` on ``person_img`` and return an RGB array.

    The garment is resized to the person image's size. Wherever the garment's
    alpha channel is greater than zero, its RGB pixel replaces the person's
    pixel. A garment without an alpha channel (such as the RGB images the
    diffusion pipeline returns) is treated as fully opaque, the same way PIL
    converts RGB to RGBA. This is a simple overlay; a realistic try-on would
    need segmentation/masking.

    Args:
        person_img: Array-like image of shape (H, W), (H, W, C) with C >= 1.
            Grayscale is expanded to three channels and any alpha is dropped.
        garment_img: A PIL image (uses ``convert``/``resize``) or a NumPy
            array of shape (h, w) or (h, w, C).

    Returns:
        ``np.ndarray`` of shape (H, W, 3).
    """
    person = _as_rgb(np.asarray(person_img))
    height, width = person.shape[:2]
    if isinstance(garment_img, np.ndarray):
        garment = _as_rgba(_resize_nearest(garment_img, height, width))
    else:
        # PIL path: convert("RGBA") adds an opaque alpha channel to RGB input,
        # so the mask below is always well-formed.
        garment = np.asarray(garment_img.convert("RGBA").resize((width, height)))
    mask = garment[:, :, 3:] > 0
    return np.where(mask, garment[:, :, :3], person)


def _as_rgb(img):
    """Return *img* as an (H, W, 3) array: replicate gray, drop extra channels."""
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        # Grayscale (optionally with alpha): replicate the luminance channel.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    return img[:, :, :3]


def _as_rgba(img):
    """Return *img* as an (H, W, 4) array, adding an opaque alpha if missing."""
    if img.ndim == 3 and img.shape[2] == 4:
        return img
    rgb = _as_rgb(img)
    alpha = np.full(rgb.shape[:2] + (1,), 255, dtype=rgb.dtype)
    return np.concatenate([rgb, alpha], axis=2)


def _resize_nearest(img, height, width):
    """Nearest-neighbour resize of an array image to (height, width)."""
    rows = np.arange(height) * img.shape[0] // height
    cols = np.arange(width) * img.shape[1] // width
    return img[rows[:, np.newaxis], cols]
# Layout CSS: each step column is centred and capped at 430px. No element in
# this UI currently uses the #col-showcase id.
css = """
#col-left {
    margin: 0 auto;
    max-width: 430px;
}
#col-mid {
    margin: 0 auto;
    max-width: 430px;
}
#col-right {
    margin: 0 auto;
    max-width: 430px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""

with gr.Blocks(css=css) as Tryon:
    gr.HTML("<h1>Virtual Try-On with Text-based Garment Generation</h1>")
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML("<h3>Step 1: Upload a person image ⬇️</h3>")
            # `source=` was removed in Gradio 4 (replaced by `sources=`) and
            # raises TypeError there; upload is enabled by default on both
            # Gradio 3 and 4+, so the argument is omitted.
            person_img = gr.Image(label="Person Image", type="numpy")
        with gr.Column(elem_id="col-mid"):
            gr.HTML("<h3>Step 2: Describe the garment ⬇️</h3>")
            cloth_description = gr.Textbox(
                label="Garment Description",
                placeholder="e.g., red dress with floral pattern",
            )
        with gr.Column(elem_id="col-right"):
            gr.HTML("<h3>Step 3: Generate Try-On Image ⬇️</h3>")
            result_img = gr.Image(label="Result", show_share_button=False)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            seed_used = gr.Number(label="Seed Used", interactive=False)
            result_info = gr.Text(label="Status", interactive=False)
            generate_button = gr.Button(value="Run")
    # Event wiring must happen inside the Blocks context.
    generate_button.click(
        fn=generate_garment,
        inputs=[person_img, cloth_description, seed, randomize_seed],
        outputs=[result_img, seed_used, result_info],
    )

Tryon.launch()