fffiloni committed
Commit 3b79011
1 Parent(s): 4b76a05

Update app.py

Files changed (1):
  1. app.py +14 -11
app.py CHANGED
@@ -12,24 +12,26 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
 def preprocess_image(image):
     return image, gr.State([]), gr.State([]), image
 
-def get_point(tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
+def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
 
     tracking_points.value.append(evt.index)
     print(f"TRACKING POINT: {tracking_points.value}")
 
-    trackings_input_label.value.append(1)
+    if point_type == "include":
+        trackings_input_label.value.append(1)
+    elif point_type == "exclude":
+        trackings_input_label.value.append(0)
     print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
-    # for SAM2
-    # input_point = np.array(tracking_points.value)
-    # print(f"SAM2 INPUT POINT: {input_point}")
-    # input_label = np.array([1])
-
+
     transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
     transparent_layer = np.zeros((h, w, 4))
-    for track in tracking_points.value:
-        cv2.circle(transparent_layer, track, 5, (255, 0, 0, 255), -1)
+    for index, track in enumerate(tracking_points.value):
+        if trackings_input_label.value[index] == 1:
+            cv2.circle(transparent_layer, track, 5, (0, 0, 255, 255), -1)
+        else:
+            cv2.circle(transparent_layer, track, 5, (255, 0, 0, 255), -1)
 
     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
     selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
@@ -143,13 +145,14 @@ with gr.Blocks() as demo:
     with gr.Row():
         input_image = gr.Image(label="input image", interactive=True, type="filepath")
         with gr.Column():
+            point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
             points_map = gr.Image(label="points map", interactive=False)
             submit_btn = gr.Button("Submit")
-            output_result = gr.Gallery()
+            output_result = gr.Image()
 
     input_image.upload(preprocess_image, input_image, [first_frame_path, tracking_points, trackings_input_label, points_map])
 
-    points_map.select(get_point, [tracking_points, trackings_input_label, first_frame_path], [tracking_points, trackings_input_label, points_map])
+    points_map.select(get_point, [point_type, tracking_points, trackings_input_label, first_frame_path], [tracking_points, trackings_input_label, points_map])
 
 
     submit_btn.click(
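
For reference, the comments deleted in this commit (`# input_point = np.array(tracking_points.value)`, `# input_label = np.array([1])`) hint at how the collected clicks are eventually consumed. A minimal sketch of that step, assuming an already-built SAM2ImagePredictor named `predictor` and the uploaded frame as an RGB array `image` (both names are illustrative, not from this commit): the accumulated points and the new 1/0 include/exclude labels are converted to NumPy arrays and passed as point prompts.

import numpy as np

# Sketch only, not part of the commit. `predictor` is an assumed
# SAM2ImagePredictor instance; `image` is the uploaded frame as an RGB array.
input_point = np.array(tracking_points.value)        # (N, 2) click coordinates collected by get_point
input_label = np.array(trackings_input_label.value)  # 1 = include, 0 = exclude, set from the radio choice

predictor.set_image(image)
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,
)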