File size: 1,800 Bytes
56a9e12 6648f32 56a9e12 6648f32 9112e74 84e526f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import gradio as gr
import torch
from process import load_seg_model, get_palette, generate_mask
from PIL import Image
import os
# Model initialization: choose the torch device once, preferring CUDA and
# falling back to the CPU when no CUDA device is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_models(checkpoint_path: str = "model/cloth_segm.pth", palette_size: int = 4):
    """Load the segmentation network and its palette.

    Args:
        checkpoint_path: Path of the weights file handed to
            ``load_seg_model``. The default is the path this app has always
            used, so a bare ``load_models()`` call behaves as before. A
            relative path is resolved against the current working directory.
        palette_size: Value forwarded unchanged to ``get_palette``.

    Returns:
        A ``(net, palette)`` tuple holding the results of ``load_seg_model``
        and ``get_palette`` respectively.

    Raises:
        gr.Error: If either loader raises. The original exception is attached
            as ``__cause__``.
    """
    try:
        net = load_seg_model(checkpoint_path, device=device)
        palette = get_palette(palette_size)
    except Exception as exc:
        # Chain explicitly so the traceback reports the loader failure as the
        # direct cause rather than as an incidental exception context.
        raise gr.Error(f"Model failed to load: {exc}") from exc
    return net, palette
# Load the network and palette once, at import time; predict() reads both as
# module-level globals on every call.
net, palette = load_models()
def predict(image):
    """Generate a segmentation mask for one input image.

    Args:
        image: A ``PIL.Image.Image``, or another object that
            ``Image.fromarray`` can convert; ``None`` when no image was
            supplied.

    Returns:
        Whatever ``generate_mask`` returns for the image, using the
        module-level ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: If ``image`` is ``None``, or if conversion or mask
            generation fails. In the latter case the original exception is
            attached as ``__cause__``.
    """
    if image is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        # Normalise non-PIL input to a PIL image before segmentation.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        return generate_mask(image, net=net, palette=palette, device=device)
    except Exception as exc:
        # Surface the failure to the caller as a gr.Error while keeping the
        # underlying exception chained for the server-side traceback.
        raise gr.Error(f"Processing error: {exc}") from exc
def _example_images(directory="examples"):
    """Return the sorted paths of PNG/JPEG files directly inside *directory*.

    Returns an empty list when *directory* is missing or is not a directory.
    The extension match is case-insensitive, and sorting makes the order
    independent of the arbitrary order ``os.listdir`` yields.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    )


# Interface
with gr.Blocks(title="Cloth Segmentation") as demo:
    # NOTE(review): the "π" below looks like a mis-decoded emoji; confirm the
    # intended glyph before changing the heading text.
    gr.Markdown("## π Cloth Segmentation Tool")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
            )
            btn = gr.Button("Generate Mask", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="Segmentation Result")

    # Optional examples: only rendered when at least one image was found, so
    # an empty or missing "examples" directory simply omits the widget.
    example_files = _example_images("examples")
    if example_files:
        gr.Examples(
            examples=example_files,
            inputs=img_input,
            outputs=img_output,
            fn=predict,
            cache_examples=True,
        )

    btn.click(predict, inputs=img_input, outputs=img_output)

if __name__ == "__main__":
    demo.launch()