File size: 2,497 Bytes
896437a
e657276
896437a
e657276
 
896437a
e657276
 
 
896437a
c396ac7
e657276
c396ac7
 
e657276
896437a
e657276
 
 
 
 
6984480
e657276
 
 
 
 
 
 
6984480
 
896437a
e657276
 
 
 
 
 
 
9112e74
 
e657276
 
 
 
 
 
 
 
9112e74
e657276
 
 
 
c396ac7
e657276
 
 
 
 
 
 
 
 
9112e74
e657276
c396ac7
 
e657276
c396ac7
 
1369068
c396ac7
e657276
 
 
 
 
9112e74
c396ac7
e657276
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import torch
from process import load_seg_model, get_palette, generate_mask
from PIL import Image
import os

# Initialize model once at import time; process_image below reuses these
# module-level objects (net, palette, device) for every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "model/cloth_segm.pth"

try:
    net = load_seg_model(model_path, device=device)
    # NOTE(review): 4 is presumably the number of mask classes/colours the
    # model emits — confirm against get_palette in process.py.
    palette = get_palette(4)
except Exception as e:
    # Fail fast at startup with a clear message, chaining the original
    # exception so its traceback is preserved as __cause__.
    raise RuntimeError(f"Failed to load model: {e}") from e

def process_image(input_img):
    """Process input image and return segmentation mask.

    Args:
        input_img: The image to segment. The UI's input component is
            configured with ``type="pil"``, so it normally arrives as a PIL
            image; anything else is converted with ``Image.fromarray``.

    Returns:
        The result of ``generate_mask`` for the image, displayed in the
        "Segmentation Result" component.

    Raises:
        gr.Error: If no image was supplied, or if conversion/segmentation
            fails. In the latter case the original exception is chained as
            ``__cause__`` so the server-side traceback keeps the real cause.
    """
    if input_img is None:
        raise gr.Error("Please upload or capture an image first")

    try:
        # Convert to PIL Image if it's not already (e.g. an array from a
        # direct call rather than from the UI component).
        if not isinstance(input_img, Image.Image):
            input_img = Image.fromarray(input_img)

        # Generate mask with the module-level model, palette and device.
        return generate_mask(input_img, net=net, palette=palette, device=device)
    except gr.Error:
        # Already a user-facing message; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Error processing image: {e}") from e

# Create simple interface: an image input (upload or webcam) with a
# "Process" button on the left, the segmentation result on the right.
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("""
    # 🧥 Cloth Segmentation App
    Upload an image or capture from your camera to get segmentation results.
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam"],
                type="pil",  # process_image receives a PIL image
                label="Input Image",
                interactive=True
            )
            submit_btn = gr.Button("Process", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Result",
                interactive=False
            )

    # Examples section (optional): image files found in ./input, if present.
    example_dir = "input"
    if os.path.isdir(example_dir):
        # os.listdir returns entries in arbitrary order; sort so the examples
        # are listed deterministically. The isfile check skips directories
        # whose names happen to end in an image extension.
        example_images = sorted(
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            and os.path.isfile(os.path.join(example_dir, f))
        )

        # Only add the Examples component when there is something to show.
        # cache_examples=True precomputes each example's output by running
        # process_image on it.
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=[input_image],
                outputs=[output_image],
                fn=process_image,
                cache_examples=True,
                label="Example Images"
            )

    # Run segmentation on the current input when "Process" is clicked.
    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image
    )

# Script entry point: serve the app on every network interface at port 7860,
# with show_error enabled so errors are surfaced in the UI.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)