Peng Shiya commited on
Commit
38277a1
1 Parent(s): cba1a87

feature: separate annotation and cutout

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. app.py +26 -19
  3. app_configs.py +2 -1
  4. service.py +23 -1
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  __pycache__/
2
- model/
 
 
1
  __pycache__/
2
+ model/
3
+ flagged/
app.py CHANGED
@@ -33,31 +33,31 @@ with block:
33
  return []
34
  def point_labels_empty():
35
  return []
 
36
  point_coords = gr.State(point_coords_empty)
37
  point_labels = gr.State(point_labels_empty)
38
- raw_image = gr.Image(type='pil', visible=False)
 
39
 
40
  # UI
41
- with gr.Row():
42
- with gr.Column():
43
  input_image = gr.Image(label='Input', height=512, type='pil')
44
- with gr.Row():
45
- point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
46
- reset_btn = gr.Button('Reset')
47
- run_btn = gr.Button('Run', variant = 'primary')
48
- gr.Examples(examples=[['examples/cat-256.png','examples/cat-256.png']],inputs=[input_image, raw_image])
49
- with gr.Column():
50
- with gr.Tab('Cutout'):
51
- cutout_gallery = gr.Gallery()
52
- with gr.Tab('Annotation'):
53
- masks_annotated_image = gr.AnnotatedImage(label='Segments')
54
-
55
  # components
56
- components = {point_coords, point_labels, raw_image, input_image, point_label_radio, reset_btn, run_btn, cutout_gallery, masks_annotated_image}
 
 
57
 
58
  # event - init coords
59
  def on_reset_btn_click(raw_image):
60
- return raw_image, point_coords_empty(), point_labels_empty(), None
61
  reset_btn.click(on_reset_btn_click, [raw_image], [input_image, point_coords, point_labels], queue=False)
62
 
63
  def on_input_image_upload(input_image):
@@ -91,9 +91,16 @@ with block:
91
  point_coords=np.array(inputs[point_coords]),
92
  point_labels=np.array(inputs[point_labels]))
93
  annotated = (image, [(masks[i], f'Mask {i}') for i in range(len(masks))])
94
- cutouts = [service.cutout(image, mask) for mask in masks]
95
- return cutouts, annotated
96
- run_btn.click(on_run_btn_click, components, [cutout_gallery, masks_annotated_image], queue=True)
 
 
 
 
 
 
 
97
 
98
  if __name__ == '__main__':
99
  block.queue()
 
33
  return []
34
  def point_labels_empty():
35
  return []
36
+ raw_image = gr.Image(type='pil', visible=False)
37
  point_coords = gr.State(point_coords_empty)
38
  point_labels = gr.State(point_labels_empty)
39
+ masks = gr.State()
40
+ cutout_idx = gr.State(set())
41
 
42
  # UI
43
+ with gr.Column():
44
+ with gr.Row():
45
  input_image = gr.Image(label='Input', height=512, type='pil')
46
+ masks_annotated_image = gr.AnnotatedImage(label='Segments')
47
+ with gr.Row():
48
+ point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
49
+ reset_btn = gr.Button('Reset')
50
+ run_btn = gr.Button('Run', variant = 'primary')
51
+ cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain')
52
+
 
 
 
 
53
  # components
54
+ components = {
55
+ point_coords, point_labels, raw_image, masks, cutout_idx,
56
+ input_image, point_label_radio, reset_btn, run_btn, masks_annotated_image}
57
 
58
  # event - init coords
59
  def on_reset_btn_click(raw_image):
60
+ return raw_image, point_coords_empty(), point_labels_empty(), None, []
61
  reset_btn.click(on_reset_btn_click, [raw_image], [input_image, point_coords, point_labels], queue=False)
62
 
63
  def on_input_image_upload(input_image):
 
91
  point_coords=np.array(inputs[point_coords]),
92
  point_labels=np.array(inputs[point_labels]))
93
  annotated = (image, [(masks[i], f'Mask {i}') for i in range(len(masks))])
94
+ return annotated, masks, set()
95
+ run_btn.click(on_run_btn_click, components, [masks_annotated_image, masks, cutout_idx], queue=True)
96
+
97
+ # event - get cutout
98
+ def on_masks_annotated_image_select(inputs, evt:gr.SelectData):
99
+ inputs[cutout_idx].add(evt.index)
100
+ cutouts = [service.cutout(inputs[raw_image], inputs[masks][idx]) for idx in list(inputs[cutout_idx])]
101
+ tight_cutouts = [service.crop_empty(cutout) for cutout in cutouts]
102
+ return inputs[cutout_idx], tight_cutouts
103
+ masks_annotated_image.select(on_masks_annotated_image_select, components, [cutout_idx, cutout_galary])
104
 
105
  if __name__ == '__main__':
106
  block.queue()
app_configs.py CHANGED
@@ -2,4 +2,5 @@ model_type = r'vit_b'
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
  device = 'cpu'
5
- enable_segment_all = False
 
 
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
  device = 'cpu'
5
+ enable_segment_all = False
6
+ flagging_dir = r'.\flagged'
service.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import IO, List
 
2
  import torch
3
  from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
4
  from PIL import Image
@@ -87,4 +88,25 @@ def box_pts_to_xyxy(pt1, pt2):
87
  """
88
  x1, y1 = pt1
89
  x2, y2 = pt2
90
- return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import IO, List
2
+ import cv2
3
  import torch
4
  from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
5
  from PIL import Image
 
88
  """
89
  x1, y1 = pt1
90
  x2, y2 = pt2
91
+ return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
92
+
93
+ def crop_empty(image:Image.Image):
94
+ # Convert image to numpy array
95
+ np_image = np.array(image)
96
+
97
+ # Find non-transparent pixels
98
+ non_transparent_pixels = np_image[:, :, 3] > 0
99
+
100
+ # Calculate bounding box coordinates
101
+ rows = np.any(non_transparent_pixels, axis=1)
102
+ cols = np.any(non_transparent_pixels, axis=0)
103
+ ymin, ymax = np.where(rows)[0][[0, -1]]
104
+ xmin, xmax = np.where(cols)[0][[0, -1]]
105
+
106
+ # Crop the image
107
+ cropped_image = np_image[ymin:ymax+1, xmin:xmax+1, :]
108
+
109
+ # Convert cropped image back to PIL image
110
+ pil_image = Image.fromarray(cropped_image)
111
+
112
+ return pil_image