# DeXFit-TryOn / app.py
# Hugging Face Space page header (scraped): verified commit 7627f49
# "Update app.py" by Vaibhavnaik12, 7.77 kB; page links: raw, history, blame.
import argparse
import os
from datetime import datetime
import gradio as gr
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image
from model.cloth_masker import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
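# NOTE(review): placeholder left over from an edit. Definitions used further
# down (e.g. `submit_function`) are not shown in this file; confirm they are
# defined in this module before `app_gradio()` runs.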
# HTML banner shown at the top of the demo page (passed to gr.Markdown in
# app_gradio). It is assembled line by line; the empty first and last entries
# reproduce the newline that opens and closes the text block.
# NOTE(review): the logo uses a relative `src`; confirm Gradio actually serves
# `resource/DeXFIT.png` (local files may need a `/file=` prefix) — TODO verify.
HEADER = "\n".join(
    [
        "",
        '<p style="text-align: center;">',
        '<img src="resource/DeXFIT.png" alt="DeX Logo" style="height: 100px;">',
        "</p>",
        '<h1 style="text-align: center; color: #101820;"> DEX FIT Virtual Try-On with Diffusion Models </h1>',
        "<br>",
        '<p style="color: #101820;">· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span style="color: #00685E;">seed</span> for normal outcomes.</p>',
        "",
    ]
)
def _example_paths(root_path, *subdirs):
    """Return the example image paths inside ``root_path/<subdirs...>``.

    Names are sorted so the example gallery has a stable order
    (``os.listdir`` returns entries in arbitrary order), and hidden entries
    such as ``.gitkeep`` / ``.DS_Store`` are skipped because they are not
    images.
    """
    folder = os.path.join(root_path, *subdirs)
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if not name.startswith(".")
    ]


def app_gradio():
    """Build the DeX FIT virtual try-on Gradio UI and launch it.

    Left column: person image editor (a hand-drawn mask takes priority),
    condition (cloth) image, cloth-type selector, submit button and an
    accordion of sampling options. Right column: the result image and
    example galleries loaded from ``resource/demo/example``.

    The submit button sends (person_image, cloth_image, cloth_type,
    num_inference_steps, guidance_scale, seed, show_type) to
    ``submit_function`` and shows its output in the result image.

    NOTE(review): ``submit_function`` is not defined in the code visible
    alongside this function; it must exist at module level before this is
    called.
    """
    # Gradio 4 components accept no `style=` / `label_style=` keyword
    # arguments (passing them raises TypeError), so the custom colours are
    # applied through CSS keyed on each component's elem_id instead.
    # TODO(review): confirm the `#cloth-type` selector colours the intended
    # text in the installed Gradio version.
    css = (
        "#main {background-color: #F4F4F1;} "
        "#submit-button {background: #00685E; color: #FFFFFF;} "
        "#cloth-type {color: #101820;}"
    )
    with gr.Blocks(title="CatVTON", css=css) as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column(scale=1, min_width=350):
                with gr.Row():
                    person_image = gr.ImageEditor(
                        interactive=True, label="Person Image", type="filepath"
                    )
                with gr.Row():
                    with gr.Column(scale=1, min_width=230):
                        cloth_image = gr.Image(
                            interactive=True, label="Condition Image", type="filepath"
                        )
                    with gr.Column(scale=1, min_width=120):
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                        )
                        cloth_type = gr.Radio(
                            label="Try-On Cloth Type",
                            choices=["upper", "lower", "overall"],
                            value="upper",
                            elem_id="cloth-type",
                        )
                submit = gr.Button("Submit", elem_id="submit-button")
                gr.Markdown(
                    '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                )
                gr.Markdown(
                    '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                )
                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps = gr.Slider(
                        label="Inference Step", minimum=10, maximum=100, step=5, value=50
                    )
                    guidance_scale = gr.Slider(
                        label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                    )
                    seed = gr.Slider(
                        label="Seed", minimum=-1, maximum=10000, step=1, value=1000
                    )
                    show_type = gr.Radio(
                        label="Show Type",
                        choices=["result only", "input & result", "input & mask & result"],
                        value="input & mask & result",
                    )
            with gr.Column(scale=2, min_width=500):
                result_image = gr.Image(interactive=False, label="Result")
                with gr.Row():
                    root_path = "resource/demo/example"
                    with gr.Column():
                        # Person examples load straight into the editor. The
                        # previous `image_path.change(...)` relay referenced a
                        # component that was never created (NameError), so it
                        # has been dropped.
                        gr.Examples(
                            examples=_example_paths(root_path, "person", "men"),
                            examples_per_page=4,
                            inputs=person_image,
                            label="Person Examples ①",
                        )
                        gr.Examples(
                            examples=_example_paths(root_path, "person", "women"),
                            examples_per_page=4,
                            inputs=person_image,
                            label="Person Examples ②",
                        )
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion" style="color: #00685E;">OOTDiffusion</a> and <a href="https://www.outfitanyone.org" style="color: #00685E;">OutfitAnyone</a>.</span>'
                        )
                    with gr.Column():
                        gr.Examples(
                            examples=_example_paths(root_path, "condition", "upper"),
                            examples_per_page=4,
                            inputs=cloth_image,
                            label="Condition Upper Examples",
                        )
                        gr.Examples(
                            examples=_example_paths(root_path, "condition", "overall"),
                            examples_per_page=4,
                            inputs=cloth_image,
                            label="Condition Overall Examples",
                        )
                        gr.Examples(
                            examples=_example_paths(root_path, "condition", "person"),
                            examples_per_page=4,
                            inputs=cloth_image,
                            label="Condition Reference Person Examples",
                        )
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet.</span>'
                        )
        # Event wiring must happen inside the Blocks context.
        submit.click(
            submit_function,
            [
                person_image,
                cloth_image,
                cloth_type,
                num_inference_steps,
                guidance_scale,
                seed,
                show_type,
            ],
            result_image,
        )
    demo.queue().launch(share=True, show_error=True)
def _main():
    """Script entry point: build the Gradio demo and launch it."""
    app_gradio()


if __name__ == "__main__":
    _main()