File size: 1,800 Bytes
56a9e12
6648f32
 
 
 
56a9e12
6648f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9112e74
84e526f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
from process import load_seg_model, get_palette, generate_mask
from PIL import Image
import os

# Model initialization
# Prefer the GPU when PyTorch reports CUDA support; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def load_models(checkpoint_path="model/cloth_segm.pth"):
    """Load the cloth-segmentation network and its colour palette.

    Args:
        checkpoint_path: Checkpoint file handed to ``load_seg_model``. The
            default is the path this app has always used, relative to the
            current working directory.

    Returns:
        A ``(net, palette)`` tuple: the model returned by ``load_seg_model``
        (loaded onto the module-level ``device``) and the value returned by
        ``get_palette``.

    Raises:
        gr.Error: If loading fails for any reason. The original exception is
            chained as ``__cause__`` so its traceback is preserved. Because
            this is called at import time, such an error aborts app startup.
    """
    try:
        net = load_seg_model(checkpoint_path, device=device)
        # 4 is forwarded verbatim to get_palette; presumably the number of
        # segmentation classes of this checkpoint -- confirm in process.py.
        palette = get_palette(4)
        return net, palette
    except Exception as e:
        raise gr.Error(f"Model failed to load: {e}") from e

# Load the network and palette once at import time; predict() reuses these
# module-level objects for every request.
(net, palette) = load_models()

def predict(image):
    """Run cloth segmentation on one image and return the result for display.

    Args:
        image: The input image. The UI component uses ``type="pil"``, so this
            is normally a ``PIL.Image.Image``. Any other value is converted
            with ``Image.fromarray`` (presumably a NumPy array, e.g. from an
            API caller -- confirm).

    Returns:
        Whatever ``generate_mask`` returns, shown in the output ``gr.Image``.

    Raises:
        gr.Error: If no image was supplied, or if conversion or segmentation
            fails. A ``gr.Error`` raised further down is passed through
            unchanged rather than being wrapped a second time.
    """
    if image is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        return generate_mask(image, net=net, palette=palette, device=device)
    except gr.Error:
        # Already a user-facing message; don't prefix it with "Processing error".
        raise
    except Exception as e:
        # Chain the cause so the full traceback still reaches the server log.
        raise gr.Error(f"Processing error: {e}") from e

# Interface

_EXAMPLES_DIR = "examples"
_EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _find_example_images(directory=_EXAMPLES_DIR):
    """Return the sorted paths of image files directly inside *directory*.

    Returns an empty list when *directory* does not exist or is not a
    directory. Extension matching is case-insensitive, so ``photo.JPG`` is
    included. Sorting gives a stable order; ``os.listdir`` order is arbitrary.
    """
    if not os.path.isdir(directory):
        return []
    paths = (os.path.join(directory, name) for name in os.listdir(directory))
    return sorted(
        path
        for path in paths
        if os.path.isfile(path) and path.lower().endswith(_EXAMPLE_EXTENSIONS)
    )


with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("## 👕 Cloth Segmentation Tool")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
            )
            btn = gr.Button("Generate Mask", variant="primary")

        with gr.Column():
            img_output = gr.Image(label="Segmentation Result")

    # Optional examples: rendered only when the folder holds at least one image.
    # cache_examples=True has Gradio precompute predict's output for each one.
    example_images = _find_example_images()
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=img_input,
            outputs=img_output,
            fn=predict,
            cache_examples=True,
        )

    btn.click(predict, inputs=img_input, outputs=img_output)

if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    demo.launch()