|
import gradio as gr |
|
import torch |
|
from process import load_seg_model, get_palette, generate_mask |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Select the torch device string used for model loading and inference:
# "cuda" whenever PyTorch reports an available CUDA device, else "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
def load_models(model_path="model/cloth_segm.pth", palette_size=4):
    """Load the cloth-segmentation network and its colour palette.

    Args:
        model_path: Checkpoint path forwarded to ``load_seg_model``. Defaults
            to the checkpoint this app ships with. Note the default is a
            relative path, so it is resolved against the current working
            directory.
        palette_size: Argument forwarded to ``get_palette`` (defaults to 4).
            NOTE(review): presumably the number of segmentation classes —
            confirm against ``process.get_palette``.

    Returns:
        A ``(net, palette)`` tuple: the object returned by ``load_seg_model``
        (placed on the module-level ``device``) and the object returned by
        ``get_palette``.

    Raises:
        gr.Error: if either helper raises. The original exception is chained
            as ``__cause__`` so the full traceback is kept.
    """
    try:
        net = load_seg_model(model_path, device=device)
        palette = get_palette(palette_size)
    except Exception as e:
        raise gr.Error(f"Model failed to load: {e}") from e
    return net, palette
|
|
|
# Load the network and palette once, when the module is imported, so every
# call to `predict` reuses the same objects instead of reloading weights.
[net, palette] = load_models()
|
|
|
def predict(image):
    """Run cloth segmentation on one image for the Gradio UI.

    Args:
        image: Value delivered by the input ``gr.Image`` component. A
            ``PIL.Image.Image`` is used as-is; anything else is converted
            with ``Image.fromarray`` (e.g. a NumPy array, should the
            component's ``type`` be changed). ``None`` means nothing was
            uploaded or captured.

    Returns:
        Whatever ``generate_mask`` returns for the image, using the
        module-level ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: if no image was provided, or if conversion/segmentation
            fails. A ``gr.Error`` raised inside the pipeline is re-raised
            unchanged; any other exception is wrapped with a
            "Processing error:" message and chained as ``__cause__``.
    """
    if image is None:
        raise gr.Error("Please upload or capture an image first")

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        return generate_mask(image, net=net, palette=palette, device=device)
    except gr.Error:
        # Already a user-facing message; don't bury it under a generic prefix.
        raise
    except Exception as e:
        raise gr.Error(f"Processing error: {e}") from e
|
|
|
|
|
_EXAMPLE_DIR = "examples"
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _find_example_images(directory=_EXAMPLE_DIR):
    """Return the sorted paths of example images inside *directory*.

    Extension matching is case-insensitive, so ``photo.JPG`` is included.
    Only regular files are returned. If *directory* is missing or is not a
    directory, an empty list is returned so the caller can skip the examples
    panel instead of crashing.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, name)
        # Sort for a stable example order; os.listdir order is arbitrary.
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(directory, name))
    ]


with gr.Blocks(title="Cloth Segmentation") as demo:
    # NOTE(review): "π" looks like a mis-encoded emoji — confirm the intended character.
    gr.Markdown("## π Cloth Segmentation Tool")

    with gr.Row():
        # Left column: image source plus the button that triggers `predict`.
        with gr.Column():
            img_input = gr.Image(
                sources=["upload", "webcam"],
                type="pil",  # `predict` receives a PIL image
                label="Input Image",
            )
            btn = gr.Button("Generate Mask", variant="primary")

        # Right column: the segmentation output.
        with gr.Column():
            img_output = gr.Image(label="Segmentation Result")

    _example_images = _find_example_images()
    # Only build the examples panel when there is something to show; caching
    # would otherwise run `predict` over an empty set for no benefit.
    if _example_images:
        gr.Examples(
            examples=_example_images,
            inputs=img_input,
            outputs=img_output,
            fn=predict,
            # Precompute `predict` outputs for the examples so they load instantly.
            cache_examples=True,
        )

    btn.click(predict, inputs=img_input, outputs=img_output)


if __name__ == "__main__":
    demo.launch()