wildoctopus commited on
Commit
e657276
·
verified ·
1 Parent(s): 397297d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -39
app.py CHANGED
@@ -1,64 +1,86 @@
1
- import PIL
2
- import torch
3
  import gradio as gr
4
- import os
5
  from process import load_seg_model, get_palette, generate_mask
 
 
6
 
7
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
8
 
9
- # Initialize models
10
  try:
11
- checkpoint_path = 'model/cloth_segm.pth'
12
- if not os.path.exists(checkpoint_path):
13
- raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
14
- net = load_seg_model(checkpoint_path, device=device)
15
  palette = get_palette(4)
16
  except Exception as e:
17
- raise RuntimeError(f"Failed to initialize models: {str(e)}")
18
 
19
- def run(img):
20
- if img is None:
21
- raise gr.Error("Please upload an image first")
 
 
22
  try:
23
- return generate_mask(img, net=net, palette=palette, device=device)
 
 
 
 
 
 
24
  except Exception as e:
25
  raise gr.Error(f"Error processing image: {str(e)}")
26
 
27
- # Handle examples
28
- image_dir = 'input'
29
- examples = []
30
- if os.path.exists(image_dir):
31
- examples = [
32
- [os.path.join(image_dir, f)]
33
- for f in sorted(os.listdir(image_dir))
34
- if f.lower().endswith(('.png', '.jpg', '.jpeg'))
35
- ]
36
-
37
- # Create interface
38
- with gr.Blocks() as demo:
39
  with gr.Row():
40
  with gr.Column():
41
- input_image = gr.Image(label="Input Image", type="pil")
 
 
 
 
 
 
 
42
  with gr.Column():
43
- output_image = gr.Image(label="Segmentation Result")
 
 
 
44
 
45
- with gr.Row():
 
 
 
 
 
 
 
 
46
  gr.Examples(
47
- examples=examples,
48
  inputs=[input_image],
49
  outputs=[output_image],
50
- fn=run,
51
  cache_examples=True,
52
  label="Example Images"
53
  )
54
 
55
- submit_btn = gr.Button("Segment", variant="primary")
56
- submit_btn.click(fn=run, inputs=input_image, outputs=output_image)
 
 
 
57
 
58
  # Launch with appropriate settings
59
- try:
60
- demo.launch(server_name="0.0.0.0", server_port=7860)
61
- except Exception as e:
62
- print(f"Error launching app: {str(e)}")
63
- # Fallback with sharing enabled
64
- demo.launch(share=True)
 
 
 
1
  import gradio as gr
2
+ import torch
3
  from process import load_seg_model, get_palette, generate_mask
4
+ from PIL import Image
5
+ import os
6
 
7
+ # Initialize model
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ model_path = "model/cloth_segm.pth"
10
 
 
11
  try:
12
+ net = load_seg_model(model_path, device=device)
 
 
 
13
  palette = get_palette(4)
14
  except Exception as e:
15
+ raise RuntimeError(f"Failed to load model: {str(e)}")
16
 
17
+ def process_image(input_img):
18
+ """Process input image and return segmentation mask"""
19
+ if input_img is None:
20
+ raise gr.Error("Please upload or capture an image first")
21
+
22
  try:
23
+ # Convert to PIL Image if it's not already
24
+ if not isinstance(input_img, Image.Image):
25
+ input_img = Image.fromarray(input_img)
26
+
27
+ # Generate mask
28
+ output_mask = generate_mask(input_img, net=net, palette=palette, device=device)
29
+ return output_mask
30
  except Exception as e:
31
  raise gr.Error(f"Error processing image: {str(e)}")
32
 
33
+ # Create simple interface
34
+ with gr.Blocks(title="Cloth Segmentation") as demo:
35
+ gr.Markdown("""
36
+ # 🧥 Cloth Segmentation App
37
+ Upload an image or capture from your camera to get segmentation results.
38
+ """)
39
+
 
 
 
 
 
40
  with gr.Row():
41
  with gr.Column():
42
+ input_image = gr.Image(
43
+ sources=["upload", "webcam"],
44
+ type="pil",
45
+ label="Input Image",
46
+ interactive=True
47
+ )
48
+ submit_btn = gr.Button("Process", variant="primary")
49
+
50
  with gr.Column():
51
+ output_image = gr.Image(
52
+ label="Segmentation Result",
53
+ interactive=False
54
+ )
55
 
56
+ # Examples section (optional)
57
+ example_dir = "input"
58
+ if os.path.exists(example_dir):
59
+ example_images = [
60
+ os.path.join(example_dir, f)
61
+ for f in os.listdir(example_dir)
62
+ if f.lower().endswith(('.png', '.jpg', '.jpeg'))
63
+ ]
64
+
65
  gr.Examples(
66
+ examples=example_images,
67
  inputs=[input_image],
68
  outputs=[output_image],
69
+ fn=process_image,
70
  cache_examples=True,
71
  label="Example Images"
72
  )
73
 
74
+ submit_btn.click(
75
+ fn=process_image,
76
+ inputs=input_image,
77
+ outputs=output_image
78
+ )
79
 
80
  # Launch with appropriate settings
81
+ if __name__ == "__main__":
82
+ demo.launch(
83
+ server_name="0.0.0.0",
84
+ server_port=7860,
85
+ show_error=True
86
+ )