# Source: Hugging Face file view of app.py by wildoctopus
# Last commit: "Update app.py" (6648f32, verified) — file size 1.8 kB
import gradio as gr
import torch
from process import load_seg_model, get_palette, generate_mask
from PIL import Image
import os
# Model initialization: pick the compute device once at import time.
# Use the GPU when PyTorch reports CUDA is available; fall back to CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_models(model_path="model/cloth_segm.pth", num_classes=4):
    """Load the segmentation network and the palette used to colour its mask.

    Args:
        model_path: Checkpoint path forwarded to ``load_seg_model``. Defaults
            to the path the app has always used.
        num_classes: Value forwarded to ``get_palette`` (4 in the original app;
            presumably the number of segmentation classes — confirm against
            ``process.get_palette``).

    Returns:
        A ``(net, palette)`` tuple: the object returned by ``load_seg_model``
        and the object returned by ``get_palette``.

    Raises:
        gr.Error: If either call fails. The underlying exception is chained as
            ``__cause__`` so its traceback is not lost.
    """
    try:
        net = load_seg_model(model_path, device=device)
        palette = get_palette(num_classes)
    except Exception as e:
        # This runs at import time, before any UI exists to display the message,
        # so chaining is what keeps the real failure visible in the console/logs.
        raise gr.Error(f"Model failed to load: {e}") from e
    return net, palette


# Load once at import so every request handled by `predict` reuses this model.
net, palette = load_models()
def predict(image):
    """Segment clothing in one image and return the rendered mask.

    Args:
        image: Image from the input component. The UI requests ``type="pil"``,
            but a non-PIL value (e.g. an array) is converted with
            ``Image.fromarray``. ``None`` means nothing was uploaded/captured.

    Returns:
        Whatever ``generate_mask`` returns for the image, using the module-level
        ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: If no image was provided, or if conversion or mask generation
            fails (the original exception is chained as ``__cause__``).
    """
    if image is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        return generate_mask(image, net=net, palette=palette, device=device)
    except gr.Error:
        # Already a user-facing error; don't bury it under a second
        # "Processing error:" prefix.
        raise
    except Exception as e:
        # Surface a readable message in the UI while keeping the traceback.
        raise gr.Error(f"Processing error: {e}") from e
# File extensions (compared case-insensitively) accepted as example images.
_EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _list_example_images(folder="examples"):
    """Return sorted paths of image files directly inside *folder*.

    Returns an empty list when *folder* is not a directory, so a stray file
    named like the folder does not make ``os.listdir`` raise at startup.
    """
    if not os.path.isdir(folder):
        return []
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    )


# Interface
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("## 👕 Cloth Segmentation Tool")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
            )
            btn = gr.Button("Generate Mask", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="Segmentation Result")

    # Optional examples: only add the panel when at least one image exists.
    # With cache_examples=True, predict is run on each example at startup.
    example_paths = _list_example_images("examples")
    if example_paths:
        gr.Examples(
            examples=example_paths,
            inputs=img_input,
            outputs=img_output,
            fn=predict,
            cache_examples=True,
        )

    # Event listeners are attached inside the Blocks context.
    btn.click(predict, inputs=img_input, outputs=img_output)
def _main():
    """Start the Gradio server for the `demo` interface."""
    demo.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    _main()