fffiloni commited on
Commit
e8b186a
1 Parent(s): 13f07b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -92,9 +92,11 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
92
  plt.figure(figsize=(10, 10))
93
  plt.imshow(image)
94
  show_mask(mask, plt.gca(), borders=borders) # Draw the mask with borders
 
95
  if point_coords is not None:
96
  assert input_labels is not None
97
  show_points(point_coords, input_labels, plt.gca())
 
98
  if box_coords is not None:
99
  show_box(box_coords, plt.gca())
100
  if len(scores) > 1:
@@ -127,6 +129,7 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
127
 
128
  return combined_images, mask_images
129
 
 
130
  def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
131
  image = Image.open(input_image)
132
  image = np.array(image.convert("RGB"))
@@ -167,7 +170,7 @@ def sam_process(input_image, checkpoint, tracking_points, trackings_input_label)
167
 
168
  print(masks.shape)
169
 
170
- results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=False)
171
  print(results)
172
 
173
  return results[0], mask_results[0]
@@ -211,23 +214,23 @@ with gr.Blocks() as demo:
211
  )
212
 
213
  points_map.upload(
214
- preprocess_image,
215
- points_map,
216
- [first_frame_path, tracking_points, trackings_input_label, input_image],
217
- queue=False
218
  )
219
 
220
  points_map.select(
221
- get_point,
222
- [point_type, tracking_points, trackings_input_label, first_frame_path],
223
- [tracking_points, trackings_input_label, points_map],
224
- queue=False
225
  )
226
 
227
-
228
  submit_btn.click(
229
  fn = sam_process,
230
  inputs = [input_image, checkpoint, tracking_points, trackings_input_label],
231
  outputs = [output_result, output_result_mask]
232
  )
233
- demo.launch()
 
 
92
  plt.figure(figsize=(10, 10))
93
  plt.imshow(image)
94
  show_mask(mask, plt.gca(), borders=borders) # Draw the mask with borders
95
+ """
96
  if point_coords is not None:
97
  assert input_labels is not None
98
  show_points(point_coords, input_labels, plt.gca())
99
+ """
100
  if box_coords is not None:
101
  show_box(box_coords, plt.gca())
102
  if len(scores) > 1:
 
129
 
130
  return combined_images, mask_images
131
 
132
+ @spaces.GPU()
133
  def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
134
  image = Image.open(input_image)
135
  image = np.array(image.convert("RGB"))
 
170
 
171
  print(masks.shape)
172
 
173
+ results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
174
  print(results)
175
 
176
  return results[0], mask_results[0]
 
214
  )
215
 
216
  points_map.upload(
217
+ fn = preprocess_image,
218
+ inputs = [points_map],
219
+ outputs = [first_frame_path, tracking_points, trackings_input_label, input_image],
220
+ queue = False
221
  )
222
 
223
  points_map.select(
224
+ fn = get_point,
225
+ inputs = [point_type, tracking_points, trackings_input_label, first_frame_path],
226
+ outputs = [tracking_points, trackings_input_label, points_map],
227
+ queue = False
228
  )
229
 
 
230
  submit_btn.click(
231
  fn = sam_process,
232
  inputs = [input_image, checkpoint, tracking_points, trackings_input_label],
233
  outputs = [output_result, output_result_mask]
234
  )
235
+
236
+ demo.launch(show_api=False, show_error=True)