Jiecong Lin committed on
Commit e195352 · 1 Parent(s): e90d42d

init commit

Files changed (2)
  1. Dockerfile +28 -0
  2. app.py +572 -0
Dockerfile ADDED
@@ -0,0 +1,28 @@
1
+ FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && apt-get install -y \
4
+ python3.11 \
5
+ python3.11-venv \
6
+ python3-pip \
7
+ git \
8
+ software-properties-common \
9
+ wget \
10
+ && apt-get clean \
11
+ && rm -rf /var/lib/apt/lists/*
12
+ RUN add-apt-repository -y -r ppa:jonathonf/ffmpeg-4 \
13
+ && apt-get update \
14
+ && apt-get install -y ffmpeg \
15
+ && apt-get clean \
16
+ && rm -rf /var/lib/apt/lists/*
17
+ WORKDIR /app
18
+ RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
19
+ RUN pip3 install numpy matplotlib pillow gradio==3.38.0 opencv-python ffmpeg-python
20
+ RUN git clone https://github.com/facebookresearch/segment-anything-2.git
21
+ WORKDIR /app/segment-anything-2
22
+ RUN pip3 install -e .
23
+ # RUN pip3 install -e ".[demo]"
24
+ WORKDIR /app/segment-anything-2/checkpoints
25
+ RUN ./download_ckpts.sh
26
+ WORKDIR /app/segment-anything-2
27
+ COPY app.py /app/segment-anything-2/app.py
28
+ CMD ["python3", "app.py"]
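A minimal sketch of how this image might be built and run locally (the image tag is illustrative; --gpus all assumes the NVIDIA Container Toolkit is available on the host, and the app is reached through the Gradio share link printed at startup rather than a published port):

docker build -t sam2-gradio .
docker run --gpus all sam2-gradio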
app.py ADDED
@@ -0,0 +1,572 @@
1
+ # import subprocess
2
+ # import re
3
+ # from typing import List, Tuple, Optional
4
+
5
+ # # Define the command to be executed
6
+ # command = ["python", "setup.py", "build_ext", "--inplace"]
7
+
8
+ # # Execute the command
9
+ # result = subprocess.run(command, capture_output=True, text=True)
10
+
11
+ # # Print the output and error (if any)
12
+ # print("Output:\n", result.stdout)
13
+ # print("Errors:\n", result.stderr)
14
+
15
+ # # Check if the command was successful
16
+ # if result.returncode == 0:
17
+ # print("Command executed successfully.")
18
+ # else:
19
+ # print("Command failed with return code:", result.returncode)
20
+
21
+ import gc
22
+ import math
23
+ import os
24
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
25
+ import shutil
26
+ import ffmpeg
27
+ import zipfile
28
+ import gradio as gr
29
+ import torch
30
+ import numpy as np
31
+ import matplotlib.pyplot as plt
32
+ from PIL import Image
33
+ from sam2.build_sam import build_sam2
34
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
35
+ from sam2.build_sam import build_sam2_video_predictor
36
+ import cv2
37
+
38
+ def clean(Seg_Tracker):
39
+ if Seg_Tracker is not None:
40
+ predictor, inference_state, image_predictor = Seg_Tracker
41
+ predictor.reset_state(inference_state)
42
+ del predictor
43
+ del inference_state
44
+ del image_predictor
45
+ del Seg_Tracker
46
+ gc.collect()
47
+ torch.cuda.empty_cache()
48
+ return None, ({}, {}), None, None, 0, None, None, None, 0
49
+
50
+ def get_meta_from_video(Seg_Tracker, input_video, scale_slider, checkpoint):
51
+
52
+ output_dir = 'output_frames'
53
+ output_masks_dir = 'output_masks'
54
+ output_combined_dir = 'output_combined'
55
+ clear_folder(output_dir)
56
+ clear_folder(output_masks_dir)
57
+ clear_folder(output_combined_dir)
58
+ if input_video is None:
59
+ return None, ({}, {}), None, None, 0, None, None, None, 0
60
+ cap = cv2.VideoCapture(input_video)
61
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
62
+ cap.release()
63
+ output_frames = max(1, int(total_frames * scale_slider))
64
+ frame_interval = max(1, total_frames // output_frames)
65
+ ffmpeg.input(input_video, hwaccel='cuda').output(
66
+ os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
67
+ vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
68
+ ).run()
69
+
70
+ first_frame_path = os.path.join(output_dir, '0000000.jpg')
71
+ first_frame = cv2.imread(first_frame_path)
72
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
73
+
74
+ if Seg_Tracker is not None:
75
+ del Seg_Tracker
76
+ Seg_Tracker = None
77
+ gc.collect()
78
+ torch.cuda.empty_cache()
79
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
80
+ if torch.cuda.get_device_properties(0).major >= 8:
81
+ torch.backends.cuda.matmul.allow_tf32 = True
82
+ torch.backends.cudnn.allow_tf32 = True
83
+
84
+ if checkpoint == "tiny":
85
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_tiny.pt"
86
+ model_cfg = "sam2_hiera_t.yaml"
87
+ elif checkpoint == "samll":
88
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_small.pt"
89
+ model_cfg = "sam2_hiera_s.yaml"
90
+ elif checkpoint == "base-plus":
91
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_base_plus.pt"
92
+ model_cfg = "sam2_hiera_b+.yaml"
93
+ elif checkpoint == "large":
94
+ sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_large.pt"
95
+ model_cfg = "sam2_hiera_l.yaml"
96
+
97
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
98
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
99
+
100
+ image_predictor = SAM2ImagePredictor(sam2_model)
101
+ inference_state = predictor.init_state(video_path=output_dir)
102
+ predictor.reset_state(inference_state)
103
+ return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, 0, None, None, None, 0
104
+
105
+ def mask2bbox(mask):
106
+ if len(np.where(mask > 0)[0]) == 0:
107
+ print('no mask found')
108
+ return np.array([0, 0, 0, 0]).astype(np.int64), False
109
+ x_ = np.sum(mask, axis=0)
110
+ y_ = np.sum(mask, axis=1)
111
+ x0 = np.min(np.nonzero(x_)[0])
112
+ x1 = np.max(np.nonzero(x_)[0])
113
+ y0 = np.min(np.nonzero(y_)[0])
114
+ y1 = np.max(np.nonzero(y_)[0])
115
+ return np.array([x0, y0, x1, y1]).astype(np.int64), True
116
+
117
+ def sam_stroke(Seg_Tracker, drawing_board, last_draw, frame_num, ann_obj_id):
118
+ predictor, inference_state, image_predictor = Seg_Tracker
119
+ image_path = f'output_frames/{frame_num:07d}.jpg'
120
+ image = cv2.imread(image_path)
121
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
122
+ display_image = drawing_board["image"]
123
+ image_predictor.set_image(image)
124
+ input_mask = drawing_board["mask"]
125
+ input_mask[input_mask != 0] = 255
126
+ if last_draw is not None:
127
+ diff_mask = cv2.absdiff(input_mask, last_draw)
128
+ input_mask = diff_mask
129
+ bbox, hasMask = mask2bbox(input_mask[:, :, 0])
130
+ if not hasMask :
131
+ return Seg_Tracker, display_image, display_image, last_draw
132
+ masks, scores, logits = image_predictor.predict( point_coords=None, point_labels=None, box=bbox[None, :], multimask_output=False,)
133
+ mask = masks > 0.0
134
+ masked_frame = show_mask(mask, display_image, ann_obj_id)
135
+ masked_with_rect = draw_rect(masked_frame, bbox, ann_obj_id)
136
+ frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=frame_num, obj_id=ann_obj_id, mask=mask[0])
137
+ last_draw = drawing_board["mask"]
138
+ return Seg_Tracker, masked_with_rect, masked_with_rect, last_draw
139
+
140
+ def draw_rect(image, bbox, obj_id):
141
+ cmap = plt.get_cmap("tab10")
142
+ color = np.array(cmap(obj_id)[:3])
143
+ rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
144
+ inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
145
+ x0, y0, x1, y1 = bbox
146
+ image_with_rect = cv2.rectangle(image.copy(), (x0, y0), (x1, y1), inv_color, thickness=2)
147
+ return image_with_rect
148
+
149
+ def sam_click(Seg_Tracker, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
150
+ points_dict, labels_dict = click_stack
151
+ predictor, inference_state, image_predictor = Seg_Tracker
152
+ ann_frame_idx = frame_num # the frame index we interact with
153
+ print(f'ann_frame_idx: {ann_frame_idx}')
154
+ point = np.array([[evt.index[0], evt.index[1]]], dtype=np.float32)
155
+ if point_mode == "Positive":
156
+ label = np.array([1], np.int32)
157
+ else:
158
+ label = np.array([0], np.int32)
159
+
160
+ if ann_frame_idx not in points_dict:
161
+ points_dict[ann_frame_idx] = {}
162
+ if ann_frame_idx not in labels_dict:
163
+ labels_dict[ann_frame_idx] = {}
164
+
165
+ if ann_obj_id not in points_dict[ann_frame_idx]:
166
+ points_dict[ann_frame_idx][ann_obj_id] = np.empty((0, 2), dtype=np.float32)
167
+ if ann_obj_id not in labels_dict[ann_frame_idx]:
168
+ labels_dict[ann_frame_idx][ann_obj_id] = np.empty((0,), dtype=np.int32)
169
+
170
+ points_dict[ann_frame_idx][ann_obj_id] = np.append(points_dict[ann_frame_idx][ann_obj_id], point, axis=0)
171
+ labels_dict[ann_frame_idx][ann_obj_id] = np.append(labels_dict[ann_frame_idx][ann_obj_id], label, axis=0)
172
+
173
+ click_stack = (points_dict, labels_dict)
174
+
175
+ frame_idx, out_obj_ids, out_mask_logits = predictor.add_new_points(
176
+ inference_state=inference_state,
177
+ frame_idx=ann_frame_idx,
178
+ obj_id=ann_obj_id,
179
+ points=points_dict[ann_frame_idx][ann_obj_id],
180
+ labels=labels_dict[ann_frame_idx][ann_obj_id],
181
+ )
182
+
183
+ image_path = f'output_frames/{ann_frame_idx:07d}.jpg'
184
+ image = cv2.imread(image_path)
185
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
186
+
187
+ masked_frame = image.copy()
188
+ for i, obj_id in enumerate(out_obj_ids):
189
+ mask = (out_mask_logits[i] > 0.0).cpu().numpy()
190
+ masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
191
+ masked_frame_with_markers = draw_markers(masked_frame, points_dict[ann_frame_idx], labels_dict[ann_frame_idx])
192
+
193
+ return Seg_Tracker, masked_frame_with_markers, masked_frame_with_markers, click_stack
194
+
195
+ def draw_markers(image, points_dict, labels_dict):
196
+ cmap = plt.get_cmap("tab10")
197
+ image_h, image_w = image.shape[:2]
198
+ marker_size = max(1, int(min(image_h, image_w) * 0.05))
199
+
200
+ for obj_id in points_dict:
201
+ color = np.array(cmap(obj_id)[:3])
202
+ rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
203
+ inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
204
+ for point, label in zip(points_dict[obj_id], labels_dict[obj_id]):
205
+ x, y = int(point[0]), int(point[1])
206
+ if label == 1:
207
+ cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_CROSS, markerSize=marker_size, thickness=2)
208
+ else:
209
+ cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_TILTED_CROSS, markerSize=int(marker_size / np.sqrt(2)), thickness=2)
210
+
211
+ return image
212
+
213
+ def show_mask(mask, image=None, obj_id=None):
214
+ cmap = plt.get_cmap("tab10")
215
+ cmap_idx = 0 if obj_id is None else obj_id
216
+ color = np.array([*cmap(cmap_idx)[:3], 0.6])
217
+
218
+ h, w = mask.shape[-2:]
219
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
220
+ mask_image = (mask_image * 255).astype(np.uint8)
221
+ if image is not None:
222
+ image_h, image_w = image.shape[:2]
223
+ if (image_h, image_w) != (h, w):
224
+ raise ValueError(f"Image dimensions ({image_h}, {image_w}) and mask dimensions ({h}, {w}) do not match")
225
+ colored_mask = np.zeros_like(image, dtype=np.uint8)
226
+ for c in range(3):
227
+ colored_mask[..., c] = mask_image[..., c]
228
+ alpha_mask = mask_image[..., 3] / 255.0
229
+ for c in range(3):
230
+ image[..., c] = np.where(alpha_mask > 0, (1 - alpha_mask) * image[..., c] + alpha_mask * colored_mask[..., c], image[..., c])
231
+ return image
232
+ return mask_image
233
+
234
+ def show_res_by_slider(frame_per, click_stack):
235
+ image_path = 'output_frames'
236
+ output_combined_dir = 'output_combined'
237
+
238
+ combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)])
239
+ if combined_frames:
240
+ output_masked_frame_path = combined_frames
241
+ else:
242
+ original_frames = sorted([os.path.join(image_path, img_name) for img_name in os.listdir(image_path)])
243
+ output_masked_frame_path = original_frames
244
+
245
+ total_frames_num = len(output_masked_frame_path)
246
+ if total_frames_num == 0:
247
+ print("No output results found")
248
+ return None, None, 0
249
+ else:
250
+ frame_num = math.floor(total_frames_num * frame_per / 100)
251
+ if frame_per == 100:
252
+ frame_num = frame_num - 1
253
+ chosen_frame_path = output_masked_frame_path[frame_num]
254
+ print(f"{chosen_frame_path}")
255
+ chosen_frame_show = cv2.imread(chosen_frame_path)
256
+ chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
257
+ points_dict, labels_dict = click_stack
258
+ if frame_num in points_dict and frame_num in labels_dict:
259
+ chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
260
+ return chosen_frame_show, chosen_frame_show, frame_num
261
+
262
+ def clear_folder(folder_path):
263
+ if os.path.exists(folder_path):
264
+ shutil.rmtree(folder_path)
265
+ os.makedirs(folder_path)
266
+
267
+ def zip_folder(folder_path, output_zip_path):
268
+ with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_STORED) as zipf:
269
+ for root, _, files in os.walk(folder_path):
270
+ for file in files:
271
+ file_path = os.path.join(root, file)
272
+ zipf.write(file_path, os.path.relpath(file_path, folder_path))
273
+
274
+ def tracking_objects(Seg_Tracker, frame_num, input_video):
275
+ output_dir = 'output_frames'
276
+ output_masks_dir = 'output_masks'
277
+ output_combined_dir = 'output_combined'
278
+ output_video_path = 'output_video.mp4'
279
+ output_zip_path = 'output_masks.zip'
280
+ clear_folder(output_masks_dir)
281
+ clear_folder(output_combined_dir)
282
+ if os.path.exists(output_video_path):
283
+ os.remove(output_video_path)
284
+ if os.path.exists(output_zip_path):
285
+ os.remove(output_zip_path)
286
+ video_segments = {}
287
+ predictor, inference_state, image_predictor = Seg_Tracker
288
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
289
+ video_segments[out_frame_idx] = {
290
+ out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
291
+ for i, out_obj_id in enumerate(out_obj_ids)
292
+ }
293
+ final_masked_frame = None
+ frame_files = sorted([f for f in os.listdir(output_dir) if f.endswith('.jpg')])
294
+ # for frame_idx in sorted(video_segments.keys()):
295
+ for frame_file in frame_files:
296
+ frame_idx = int(os.path.splitext(frame_file)[0])
297
+ frame_path = os.path.join(output_dir, frame_file)
298
+ image = cv2.imread(frame_path)
299
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
300
+ masked_frame = image.copy()
301
+ if frame_idx in video_segments:
302
+ for obj_id, mask in video_segments[frame_idx].items():
303
+ masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
304
+ mask_output_path = os.path.join(output_masks_dir, f'{obj_id}_{frame_idx:07d}.png')
305
+ cv2.imwrite(mask_output_path, show_mask(mask))
306
+ combined_output_path = os.path.join(output_combined_dir, f'{frame_idx:07d}.png')
307
+ combined_image_bgr = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
308
+ cv2.imwrite(combined_output_path, combined_image_bgr)
309
+ if frame_idx == frame_num:
310
+ final_masked_frame = masked_frame
311
+
312
+ cap = cv2.VideoCapture(input_video)
313
+ fps = cap.get(cv2.CAP_PROP_FPS)
314
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
315
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
316
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
317
+ cap.release()
318
+ # output_frames = int(total_frames * scale_slider)
319
+ output_frames = len([name for name in os.listdir(output_combined_dir) if os.path.isfile(os.path.join(output_combined_dir, name)) and name.endswith('.png')])
320
+ out_fps = fps * output_frames / total_frames
321
+ ffmpeg.input(os.path.join(output_combined_dir, '%07d.png'), framerate=out_fps).output(output_video_path, vcodec='h264_nvenc', pix_fmt='yuv420p').run()
322
+ # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
323
+ # out = cv2.VideoWriter(output_video_path, fourcc, out_fps, (frame_width, frame_height))
324
+
325
+ # for i in range(output_frames):
326
+ # frame_path = os.path.join(output_combined_dir, f'{i:07d}.png')
327
+ # frame = cv2.imread(frame_path)
328
+ # out.write(frame)
329
+
330
+ # out.release()
331
+
332
+ zip_folder(output_masks_dir, output_zip_path)
333
+ print("done")
334
+ return final_masked_frame, final_masked_frame, output_video_path, output_video_path, output_zip_path
335
+
336
+ def increment_ann_obj_id(ann_obj_id):
337
+ ann_obj_id += 1
338
+ return ann_obj_id
339
+
340
+ def drawing_board_get_input_first_frame(input_first_frame):
341
+ return input_first_frame
342
+
343
+ def seg_track_app():
344
+
345
+ ##########################################################
346
+ ###################### Front-end ########################
347
+ ##########################################################
348
+ css = """
349
+ #input_output_video video {
350
+ max-height: 550px;
351
+ max-width: 100%;
352
+ height: auto;
353
+ }
354
+ """
355
+
356
+ app = gr.Blocks(css=css)
357
+
358
+ with app:
359
+ gr.Markdown(
360
+ '''
361
+ <div style="text-align:center;">
362
+ <span style="font-size:3em; font-weight:bold;">SAM2</span>
363
+ </div>
364
+ '''
365
+ )
366
+
367
+ click_stack = gr.State(({}, {}))
368
+ Seg_Tracker = gr.State(None)
369
+ frame_num = gr.State(value=(int(0)))
370
+ ann_obj_id = gr.State(value=(int(0)))
371
+ last_draw = gr.State(None)
372
+
373
+ with gr.Row():
374
+ with gr.Column(scale=0.5):
375
+ with gr.Row():
376
+ tab_video_input = gr.Tab(label="Video input")
377
+ with tab_video_input:
378
+ input_video = gr.Video(label='Input video', elem_id="input_output_video")
379
+ with gr.Row():
380
+ checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
381
+ scale_slider = gr.Slider(
382
+ label="Downsampe Frame Rate",
383
+ minimum=0.0,
384
+ maximum=1.0,
385
+ step=0.25,
386
+ value=1.0,
387
+ interactive=True
388
+ )
389
+ preprocess_button = gr.Button(
390
+ value="Preprocess",
391
+ interactive=True,
392
+ )
393
+
394
+ with gr.Row():
395
+ tab_stroke = gr.Tab(label="Stroke to Box Prompt")
396
+ with tab_stroke:
397
+ drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
398
+ with gr.Row():
399
+ seg_acc_stroke = gr.Button(value="Segment", interactive=True)
400
+
401
+ tab_click = gr.Tab(label="Point Prompt")
402
+ with tab_click:
403
+ input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)
404
+ with gr.Row():
405
+ point_mode = gr.Radio(
406
+ choices=["Positive", "Negative"],
407
+ value="Positive",
408
+ label="Point Prompt",
409
+ interactive=True)
410
+
411
+ with gr.Row():
412
+ with gr.Column():
413
+ frame_per = gr.Slider(
414
+ label = "Percentage of Frames Viewed",
415
+ minimum= 0.0,
416
+ maximum= 100.0,
417
+ step=0.01,
418
+ value=0.0,
419
+ )
420
+ new_object_button = gr.Button(
421
+ value="Add new object",
422
+ interactive=True
423
+ )
424
+ track_for_video = gr.Button(
425
+ value="Start Tracking",
426
+ interactive=True,
427
+ )
428
+ reset_button = gr.Button(
429
+ value="Reset",
430
+ interactive=True,
431
+ )
432
+
433
+ with gr.Column(scale=0.5):
434
+ output_video = gr.Video(label='Visualize Results', elem_id="input_output_video")
435
+ output_mp4 = gr.File(label="Predicted video")
436
+ output_mask = gr.File(label="Predicted masks")
437
+
438
+ # with gr.Tab(label='Video examples'):
439
+ # gr.Examples(
440
+ # label="",
441
+ # examples=[
442
+ # "assets/clip_012251_fps5_07_25.mp4",
443
+ # "assets/FLARE22_Tr_0004.mp4",
444
+ # "assets/FLARE22_Tr_0016.mp4",
445
+ # "assets/FLARE22_Tr_0046.mp4",
446
+ # "assets/c_elegans_mov_cut_fps12.mp4",
447
+ # "assets/12fps_Dylan_Burnette_neutrophil.mp4",
448
+ # "assets/12fps_Dancing_cells_trimmed.mp4",
449
+ # ],
450
+ # inputs=[input_video],
451
+ # )
452
+ # gr.Examples(
453
+ # label="",
454
+ # examples=[
455
+ # "assets/12fps_volvox_microcystis_play_trimmed.mp4",
456
+ # "assets/12fps_neuron_time_lapse.mp4",
457
+ # "assets/12fps_macrophages_phagocytosis.mp4",
458
+ # "assets/12fps_worm_eats_organism_3.mp4",
459
+ # "assets/12fps_worm_eats_organism_4.mp4",
460
+ # "assets/12fps_worm_eats_organism_5.mp4",
461
+ # "assets/12fps_worm_eats_organism_6.mp4",
462
+ # "assets/12fps_02_cups.mp4",
463
+ # ],
464
+ # inputs=[input_video],
465
+ # )
466
+ gr.Markdown(
467
+ '''
468
+ <div style="text-align:center; margin-top: 20px;">
469
+ The authors of this work highly appreciate Meta AI for making SAM2 publicly available to the community.
470
+ The interface was built on <a href="https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/tutorial/tutorial%20for%20WebUI-1.0-Version.md" target="_blank">SegTracker</a>.
471
+ <a href="https://docs.google.com/document/d/1idDBV0faOjdjVs-iAHr0uSrw_9_ZzLGrUI2FEdK-lso/edit?usp=sharing" target="_blank">Data source</a>
472
+ </div>
473
+ '''
474
+ )
475
+
476
+ ##########################################################
477
+ ###################### back-end #########################
478
+ ##########################################################
479
+
480
+ # listen to the preprocess button click to get the first frame of video with scaling
481
+ preprocess_button.click(
482
+ fn=get_meta_from_video,
483
+ inputs=[
484
+ Seg_Tracker,
485
+ input_video,
486
+ scale_slider,
487
+ checkpoint
488
+ ],
489
+ outputs=[
490
+ Seg_Tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
491
+ ]
492
+ )
493
+
494
+ frame_per.release(
495
+ fn=show_res_by_slider,
496
+ inputs=[
497
+ frame_per, click_stack
498
+ ],
499
+ outputs=[
500
+ input_first_frame, drawing_board, frame_num
501
+ ]
502
+ )
503
+
504
+ # Interactively modify the mask acc click
505
+ input_first_frame.select(
506
+ fn=sam_click,
507
+ inputs=[
508
+ Seg_Tracker, frame_num, point_mode, click_stack, ann_obj_id
509
+ ],
510
+ outputs=[
511
+ Seg_Tracker, input_first_frame, drawing_board, click_stack
512
+ ]
513
+ )
514
+
515
+ # Track object in video
516
+ track_for_video.click(
517
+ fn=tracking_objects,
518
+ inputs=[
519
+ Seg_Tracker,
520
+ frame_num,
521
+ input_video,
522
+ ],
523
+ outputs=[
524
+ input_first_frame,
525
+ drawing_board,
526
+ output_video,
527
+ output_mp4,
528
+ output_mask
529
+ ]
530
+ )
531
+
532
+ reset_button.click(
533
+ fn=clean,
534
+ inputs=[
535
+ Seg_Tracker
536
+ ],
537
+ outputs=[
538
+ Seg_Tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
539
+ ]
540
+ )
541
+
542
+ new_object_button.click(
543
+ fn=increment_ann_obj_id,
544
+ inputs=[
545
+ ann_obj_id
546
+ ],
547
+ outputs=[
548
+ ann_obj_id
549
+ ]
550
+ )
551
+
552
+ tab_stroke.select(
553
+ fn=drawing_board_get_input_first_frame,
554
+ inputs=[input_first_frame,],
555
+ outputs=[drawing_board,],
556
+ )
557
+
558
+ seg_acc_stroke.click(
559
+ fn=sam_stroke,
560
+ inputs=[
561
+ Seg_Tracker, drawing_board, last_draw, frame_num, ann_obj_id
562
+ ],
563
+ outputs=[
564
+ Seg_Tracker, input_first_frame, drawing_board, last_draw
565
+ ]
566
+ )
567
+
568
+ app.queue(concurrency_count=1)
569
+ app.launch(debug=True, enable_queue=True, share=True)
570
+
571
+ if __name__ == "__main__":
572
+ seg_track_app()
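A note on serving: app.launch(share=True) exposes the interface through a temporary Gradio share link, which is how the containerized app is reached since the Dockerfile publishes no port. If a direct local binding were preferred instead (an assumption, not part of this commit), the launch call could be adjusted along these lines and the port published with docker run -p 7860:7860:

# hypothetical alternative to the share link: bind on all interfaces inside the container
app.launch(debug=True, server_name="0.0.0.0", server_port=7860)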