File size: 2,497 Bytes
896437a e657276 896437a e657276 896437a e657276 896437a c396ac7 e657276 c396ac7 e657276 896437a e657276 6984480 e657276 6984480 896437a e657276 9112e74 e657276 9112e74 e657276 c396ac7 e657276 9112e74 e657276 c396ac7 e657276 c396ac7 1369068 c396ac7 e657276 9112e74 c396ac7 e657276 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import gradio as gr
import torch
from process import load_seg_model, get_palette, generate_mask
from PIL import Image
import os
# Initialize model once at import time; request handlers read these
# module-level globals (net, palette, device) rather than reloading weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "model/cloth_segm.pth"
try:
    net = load_seg_model(model_path, device=device)
    # NOTE(review): 4 is presumably the number of segmentation classes the
    # palette must cover — confirm against process.get_palette.
    palette = get_palette(4)
except Exception as e:
    # Chain the cause so the underlying load error's traceback is preserved
    # instead of being replaced by this summary message.
    raise RuntimeError(f"Failed to load model: {e}") from e
def process_image(input_img):
    """Run cloth segmentation on one image and return the resulting mask.

    Args:
        input_img: Image from the Gradio input component. The component is
            configured with ``type="pil"``, so this is normally a PIL image;
            anything else is passed through ``PIL.Image.fromarray``.

    Returns:
        The mask produced by ``generate_mask`` for this image, shown in the
        output ``gr.Image`` component.

    Raises:
        gr.Error: If no image was supplied or segmentation fails; Gradio
            displays the message to the user.
    """
    if input_img is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        # Convert to PIL Image if it's not already
        if not isinstance(input_img, Image.Image):
            input_img = Image.fromarray(input_img)
        # Uploaded PNGs may be RGBA, palette or grayscale. NOTE(review): assumes
        # the segmentation network expects 3-channel RGB — confirm against
        # process.generate_mask.
        if input_img.mode != "RGB":
            input_img = input_img.convert("RGB")
        return generate_mask(input_img, net=net, palette=palette, device=device)
    except gr.Error:
        # Already a user-facing error; don't wrap it a second time.
        raise
    except Exception as e:
        # Chain the cause so the server-side traceback survives the re-raise.
        raise gr.Error(f"Error processing image: {e}") from e
# Create simple interface
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("""
# 🧥 Cloth Segmentation App
Upload an image or capture from your camera to get segmentation results.
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
                interactive=True,
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Result",
                interactive=False,
            )

    # Examples section (optional): shown only when ./input is a directory
    # containing at least one .png/.jpg/.jpeg file.
    example_dir = "input"
    # isdir (not exists): a regular file named "input" would make listdir raise.
    if os.path.isdir(example_dir):
        # Sorted so the example order is stable; os.listdir order is arbitrary.
        example_images = sorted(
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        # Skip the widget when the folder holds no images instead of building
        # (and caching) an empty examples set.
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=[input_image],
                outputs=[output_image],
                fn=process_image,
                # Per Gradio docs, cached examples are precomputed by calling
                # fn (process_image) on each example ahead of time.
                cache_examples=True,
                label="Example Images",
            )

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image,
    )
# Launch with appropriate settings
if __name__ == "__main__":
    # Bind on all interfaces (0.0.0.0) at port 7860 and surface handler
    # errors in the UI (show_error).
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)