fffiloni commited on
Commit
0c327c7
·
verified ·
1 Parent(s): ecdf6ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -24
app.py CHANGED
@@ -29,9 +29,9 @@ def get_point(point_type, tracking_points, trackings_input_label, first_frame_pa
29
  transparent_layer = np.zeros((h, w, 4))
30
  for index, track in enumerate(tracking_points.value):
31
  if trackings_input_label.value[index] == 1:
32
- cv2.circle(transparent_layer, track, 5, (0, 0, 255, 255), -1)
33
  else:
34
- cv2.circle(transparent_layer, track, 5, (255, 0, 0, 255), -1)
35
 
36
  transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
37
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
@@ -73,37 +73,53 @@ def show_box(box, ax):
73
  w, h = box[2] - box[0], box[3] - box[1]
74
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
75
 
76
- def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=False):
77
- masks_store = []
 
 
78
  for i, (mask, score) in enumerate(zip(masks, scores)):
 
79
  plt.figure(figsize=(10, 10))
80
  plt.imshow(image)
81
- show_mask(mask, plt.gca(), borders=borders)
82
-
83
- """
84
  if point_coords is not None:
85
  assert input_labels is not None
86
  show_points(point_coords, input_labels, plt.gca())
87
- """
88
-
89
  if box_coords is not None:
90
- # boxes
91
  show_box(box_coords, plt.gca())
92
  if len(scores) > 1:
93
  plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
94
  plt.axis('off')
95
- # plt.show()
96
 
97
  # Save the figure as a JPG file
98
- filename = f"masked_image_{i+1}.jpg"
99
- plt.savefig(filename, format='jpg', bbox_inches='tight')
 
 
 
100
 
101
- masks_store.append(filename)
102
-
103
- # Close the figure to free up memory
104
- plt.close()
 
 
 
 
 
 
 
 
 
105
 
106
- return masks_store
 
 
 
 
 
 
 
107
 
108
  def sam_process(input_image, tracking_points, trackings_input_label):
109
  image = Image.open(input_image)
@@ -135,10 +151,10 @@ def sam_process(input_image, tracking_points, trackings_input_label):
135
 
136
  print(masks.shape)
137
 
138
- results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
139
  print(results)
140
 
141
- return results[0]
142
 
143
  with gr.Blocks() as demo:
144
  first_frame_path = gr.State()
@@ -155,20 +171,33 @@ with gr.Blocks() as demo:
155
  points_map = gr.Image(label="points map", interactive=False)
156
  submit_btn = gr.Button("Submit")
157
  output_result = gr.Image()
 
158
 
159
  clear_points_btn.click(
160
  fn = preprocess_image,
161
  inputs = input_image,
162
- outputs = [first_frame_path, tracking_points, trackings_input_label, points_map]
 
 
 
 
 
 
 
 
163
  )
164
- input_image.upload(preprocess_image, input_image, [first_frame_path, tracking_points, trackings_input_label, points_map])
165
 
166
- points_map.select(get_point, [point_type, tracking_points, trackings_input_label, first_frame_path], [tracking_points, trackings_input_label, points_map])
 
 
 
 
 
167
 
168
 
169
  submit_btn.click(
170
  fn = sam_process,
171
  inputs = [input_image, tracking_points, trackings_input_label],
172
- outputs = [output_result]
173
  )
174
  demo.launch()
 
29
  transparent_layer = np.zeros((h, w, 4))
30
  for index, track in enumerate(tracking_points.value):
31
  if trackings_input_label.value[index] == 1:
32
+ cv2.circle(transparent_layer, track, 20, (0, 0, 255, 255), -1)
33
  else:
34
+ cv2.circle(transparent_layer, track, 20, (255, 0, 0, 255), -1)
35
 
36
  transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
37
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
 
73
  w, h = box[2] - box[0], box[3] - box[1]
74
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
75
 
76
+ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
77
+ combined_images = [] # List to store filenames of images with masks overlaid
78
+ mask_images = [] # List to store filenames of separate mask images
79
+
80
  for i, (mask, score) in enumerate(zip(masks, scores)):
81
+ # ---- Original Image with Mask Overlaid ----
82
  plt.figure(figsize=(10, 10))
83
  plt.imshow(image)
84
+ show_mask(mask, plt.gca(), borders=borders) # Draw the mask with borders
 
 
85
  if point_coords is not None:
86
  assert input_labels is not None
87
  show_points(point_coords, input_labels, plt.gca())
 
 
88
  if box_coords is not None:
 
89
  show_box(box_coords, plt.gca())
90
  if len(scores) > 1:
91
  plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
92
  plt.axis('off')
 
93
 
94
  # Save the figure as a JPG file
95
+ combined_filename = f"combined_image_{i+1}.jpg"
96
+ plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
97
+ combined_images.append(combined_filename)
98
+
99
+ plt.close() # Close the figure to free up memory
100
 
101
+ # ---- Separate Mask Image ----
102
+ plt.figure(figsize=(10, 10))
103
+ mask_image = np.zeros_like(image, dtype=np.uint8) # Initialize a blank image
104
+ show_mask(mask, plt.gca(), borders=False) # Draw the mask without borders
105
+
106
+ plt.axis('off')
107
+ plt.tight_layout()
108
+ plt.gca().set_axis_off()
109
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
110
+ hspace=0, wspace=0)
111
+ plt.margins(0, 0)
112
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
113
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
114
 
115
+ # Save mask image
116
+ mask_filename = f"mask_image_{i+1}.png"
117
+ plt.savefig(mask_filename, format='png', bbox_inches='tight', pad_inches=0)
118
+ mask_images.append(mask_filename)
119
+
120
+ plt.close() # Close the figure to free up memory
121
+
122
+ return combined_images, mask_images
123
 
124
  def sam_process(input_image, tracking_points, trackings_input_label):
125
  image = Image.open(input_image)
 
151
 
152
  print(masks.shape)
153
 
154
+ results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=False)
155
  print(results)
156
 
157
+ return results[0], mask_results[0]
158
 
159
  with gr.Blocks() as demo:
160
  first_frame_path = gr.State()
 
171
  points_map = gr.Image(label="points map", interactive=False)
172
  submit_btn = gr.Button("Submit")
173
  output_result = gr.Image()
174
+ output_result_mask = gr.Image()
175
 
176
  clear_points_btn.click(
177
  fn = preprocess_image,
178
  inputs = input_image,
179
+ outputs = [first_frame_path, tracking_points, trackings_input_label, points_map],
180
+ queue=False
181
+ )
182
+
183
+ input_image.upload(
184
+ preprocess_image,
185
+ input_image,
186
+ [first_frame_path, tracking_points, trackings_input_label, points_map],
187
+ queue=False
188
  )
 
189
 
190
+ points_map.select(
191
+ get_point,
192
+ [point_type, tracking_points, trackings_input_label, first_frame_path],
193
+ [tracking_points, trackings_input_label, points_map],
194
+ queue=False
195
+ )
196
 
197
 
198
  submit_btn.click(
199
  fn = sam_process,
200
  inputs = [input_image, tracking_points, trackings_input_label],
201
+ outputs = [output_result, output_result_mask]
202
  )
203
  demo.launch()