# (Commented-out helper, kept for reference: builds the package extension in place via setup.py.)
# import subprocess
# import re
# from typing import List, Tuple, Optional
# command = ["python", "setup.py", "build_ext", "--inplace"]
# result = subprocess.run(command, capture_output=True, text=True)
# print("Output:\n", result.stdout)
# print("Errors:\n", result.stderr)
# if result.returncode == 0:
#     print("Command executed successfully.")
# else:
#     print("Command failed with return code:", result.returncode)

import gc
import math
# import multiprocessing as mp
import torch.multiprocessing as mp
import os

from process_wrappers import (
    clear_folder,
    draw_markers,
    sam_click_wrapper1,
    sam_stroke_process,
    tracking_objects_process,
)

# Enable PyTorch's cuDNN backend for scaled dot product attention.
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"

import ffmpeg
import cv2


def clean():
    # Reset state: empty click stack, cleared images/videos/files, frame index and object id back to 0.
    return ({}, {}, {}), None, None, 0, None, None, None, 0


def show_res_by_slider(frame_per, click_stack):
    # Show the frame selected by the slider, preferring frames with rendered masks if available.
    image_path = '/tmp/output_frames'
    output_combined_dir = '/tmp/output_combined'
    combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)])
    if combined_frames:
        output_masked_frame_path = combined_frames
    else:
        original_frames = sorted([os.path.join(image_path, img_name) for img_name in os.listdir(image_path)])
        output_masked_frame_path = original_frames

    total_frames_num = len(output_masked_frame_path)
    if total_frames_num == 0:
        print("No output results found")
        return None, None, 0  # three values to match the slider listener's outputs
    else:
        frame_num = math.floor(total_frames_num * frame_per / 100)
        if frame_per == 100:
            frame_num = frame_num - 1
        chosen_frame_path = output_masked_frame_path[frame_num]
        print(f"{chosen_frame_path}")
        chosen_frame_show = cv2.imread(chosen_frame_path)
        chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
        points_dict, labels_dict, masks_dict = click_stack
        if frame_num in points_dict and frame_num in labels_dict:
            chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
        return chosen_frame_show, chosen_frame_show, frame_num


def increment_ann_obj_id(ann_obj_id):
    ann_obj_id += 1
    return ann_obj_id


def drawing_board_get_input_first_frame(input_first_frame):
    return input_first_frame


def sam_stroke_wrapper(click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id):
    # Run stroke-to-box segmentation in a separate process and collect (error, result) from the queue.
    queue = mp.Queue()
    p = mp.Process(target=sam_stroke_process, args=(queue, click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id))
    p.start()
    error, result = queue.get()
    p.join()
    if error:
        raise Exception(f"Error in sam_stroke_process: {error}")
    return result


def tracking_objects_wrapper(click_stack, checkpoint, frame_num, input_video):
    # Run object tracking in a separate process and collect (error, result) from the queue.
    queue = mp.Queue()
    p = mp.Process(target=tracking_objects_process, args=(queue, click_stack, checkpoint, frame_num, input_video))
    p.start()
    error, result = queue.get()
    p.join()
    if error:
        raise Exception(f"Error in tracking_objects_process: {error}")
    return result


def seg_track_app():
    import gradio as gr

    def sam_click_wrapper(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
        # Forward the clicked pixel coordinates to the point-prompt segmentation wrapper.
        return sam_click_wrapper1(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, [evt.index[0], evt.index[1]])

    def change_video(input_video):
        # When a new video is selected, reset the sliders from its fps and duration.
        if input_video is None:
            return 0, 0
        cap = cv2.VideoCapture(input_video)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        scale_slider = gr.Slider.update(minimum=1.0, maximum=fps, step=1.0, value=fps)
        frame_per = gr.Slider.update(minimum=0.0, maximum=total_frames / fps, step=1.0 / fps, value=0.0)
        return scale_slider, frame_per

    def get_meta_from_video(input_video, scale_slider):
        # Extract frames from the uploaded video with ffmpeg and return the first frame for prompting.
        output_dir = '/tmp/output_frames'
        output_masks_dir = '/tmp/output_masks'
        output_combined_dir = '/tmp/output_combined'
        clear_folder(output_dir)
        clear_folder(output_masks_dir)
        clear_folder(output_combined_dir)
        if input_video is None:
            return ({}, {}, {}), None, None, 0, None, None, None, 0
        cap = cv2.VideoCapture(input_video)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        frame_interval = max(1, int(fps // scale_slider))
        print(f"frame_interval: {frame_interval}")
        try:
            # Try hardware-accelerated decoding first.
            ffmpeg.input(input_video, hwaccel='cuda').output(
                os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
                vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
            ).run()
        except Exception:
            print("ffmpeg CUDA decoding failed, falling back to CPU")
            ffmpeg.input(input_video).output(
                os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
                vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
            ).run()
        first_frame_path = os.path.join(output_dir, '0000000.jpg')
        first_frame = cv2.imread(first_frame_path)
        first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        frame_per = gr.Slider.update(minimum=0.0, maximum=total_frames / fps, step=frame_interval / fps, value=0.0)
        return ({}, {}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0

    ##########################################################
    ######################  Front-end  ########################
    ##########################################################
    css = """
    #input_output_video video {
        max-height: 550px;
        max-width: 100%;
        height: auto;
    }
    """
    app = gr.Blocks(css=css)
    with app:
        gr.Markdown(
            '''
## MedSAM2 for Video Segmentation 🔥
MedSAM2 - Segment Anything in Medical Images and Videos: Benchmark and Deployment
GitHub | Paper | 3D Slicer Plugin | Video Tutorial | Fine-tune SAM2
This API supports box prompts (generated from a scribble) and point prompts for video segmentation with SAM2. You are welcome to join our mailing list to get updates or send feedback.
1. Upload video file
2. Select model size and downsample frame rate and run Preprocess
3. Use Stroke to Box Prompt to draw a box on the first frame, or Point Prompt to click on the first frame.
   Note: the bounding rectangle of the stroke should cover the segmentation target.
4. Click Segment to get the segmentation result
5. Click Add New Object to add a new object
6. Click Start Tracking to track the objects in the video
7. Click Reset to reset the app
8. Download the video with segmentation results
We designed this API and the 3D Slicer Plugin for medical image and video segmentation; the checkpoints are based on the original SAM2 models (https://github.com/facebookresearch/segment-anything-2). The image segmentation fine-tuning code has been released on GitHub. The video fine-tuning code is under active development and will be released as well.
If you find these tools useful, please consider citing the following papers:
Ravi, N., Gabeur, V., Hu, Y.T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädle, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K.V., Carion, N., Wu, C.Y., Girshick, R., Dollár, P., Feichtenhofer, C.: SAM 2: Segment Anything in Images and Videos. arXiv:2408.00714 (2024)
Ma, J., Kim, S., Li, F., Baharoon, M., Asakereh, R., Lyu, H., Wang, B.: Segment Anything in Medical Images and Videos: Benchmark and Deployment. arXiv:2408.03322 (2024)
Other useful resources: Official demo from MetaAI, Video tutorial from Piotr Skalski.
            '''
        )

        click_stack = gr.State(({}, {}, {}))
        frame_num = gr.State(value=0)
        ann_obj_id = gr.State(value=0)
        last_draw = gr.State(None)

        with gr.Row():
            with gr.Column(scale=0.5):
                with gr.Row():
                    tab_video_input = gr.Tab(label="Video input")
                    with tab_video_input:
                        input_video = gr.Video(label='Input video', elem_id="input_output_video")
                        with gr.Row():
                            checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
                            scale_slider = gr.Slider(
                                label="Downsample Frame Rate (fps)",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.25,
                                value=1.0,
                                interactive=True
                            )
                            preprocess_button = gr.Button(
                                value="Preprocess",
                                interactive=True,
                            )

                with gr.Row():
                    tab_stroke = gr.Tab(label="Stroke to Box Prompt")
                    with tab_stroke:
                        drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
                        with gr.Row():
                            seg_acc_stroke = gr.Button(value="Segment", interactive=True)

                    tab_click = gr.Tab(label="Point Prompt")
                    with tab_click:
                        input_first_frame = gr.Image(label='Segment result of first frame', interactive=True, height=550)
                        with gr.Row():
                            point_mode = gr.Radio(
                                choices=["Positive", "Negative"],
                                value="Positive",
                                label="Point Prompt",
                                interactive=True)

                with gr.Row():
                    with gr.Column():
                        frame_per = gr.Slider(
                            label="Time (seconds)",
                            minimum=0.0,
                            maximum=100.0,
                            step=0.01,
                            value=0.0,
                        )
                        new_object_button = gr.Button(
                            value="Add New Object",
                            interactive=True
                        )
                        track_for_video = gr.Button(
                            value="Start Tracking",
                            interactive=True,
                        )
                        reset_button = gr.Button(
                            value="Reset",
                            interactive=True,
                        )

            with gr.Column(scale=0.5):
                output_video = gr.Video(label='Visualize Results', elem_id="input_output_video")
                output_mp4 = gr.File(label="Predicted video")
                output_mask = gr.File(label="Predicted masks")

        with gr.Tab(label='Video examples'):
            gr.Examples(
                label="",
                examples=[
                    "assets/12fps_Dancing_cells_trimmed.mp4",
                    "assets/clip_012251_fps5_07_25.mp4",
                    "assets/FLARE22_Tr_0004.mp4",
                    "assets/c_elegans_mov_cut_fps12.mp4",
                ],
                inputs=[input_video],
            )
            gr.Examples(
                label="",
                examples=[
                    "assets/12fps_volvox_microcystis_play_trimmed.mp4",
                    "assets/12fps_macrophages_phagocytosis.mp4",
                    "assets/12fps_worm_eats_organism_5.mp4",
                    "assets/12fps_worm_eats_organism_6.mp4",
                    "assets/12fps_02_cups.mp4",
                ],
                inputs=[input_video],
            )

        gr.Markdown(
            '''
The authors of this work highly appreciate Meta AI for making SAM2 publicly available to the community. The interface was built on SegTracker, which is also an amazing tool for video segmentation and tracking. Data source
            '''
        )

        ##########################################################
        ######################  Back-end  #########################
        ##########################################################

        # Listen to the Preprocess button: extract frames and show the first frame of the (downsampled) video.
        preprocess_button.click(
            fn=get_meta_from_video,
            inputs=[
                input_video,
                scale_slider,
            ],
            outputs=[
                click_stack, input_first_frame, drawing_board, frame_per,
                output_video, output_mp4, output_mask, ann_obj_id
            ]
        )

        frame_per.release(
            fn=show_res_by_slider,
            inputs=[
                frame_per, click_stack
            ],
            outputs=[
                input_first_frame, drawing_board, frame_num
            ]
        )

        # Interactively modify the mask according to the click input.
        input_first_frame.select(
            fn=sam_click_wrapper,
            inputs=[
                checkpoint, frame_num, point_mode, click_stack, ann_obj_id
            ],
            outputs=[
                input_first_frame, drawing_board, click_stack
            ]
        )

        # Track objects through the video.
        track_for_video.click(
            fn=tracking_objects_wrapper,
            inputs=[
                click_stack,
                checkpoint,
                frame_num,
                input_video,
            ],
            outputs=[
                input_first_frame, drawing_board, output_video, output_mp4, output_mask
            ]
        )

        reset_button.click(
            fn=clean,
            inputs=[],
            outputs=[
                click_stack, input_first_frame, drawing_board, frame_per,
                output_video, output_mp4, output_mask, ann_obj_id
            ]
        )

        new_object_button.click(
            fn=increment_ann_obj_id,
            inputs=[ann_obj_id],
            outputs=[ann_obj_id]
        )

        tab_stroke.select(
            fn=drawing_board_get_input_first_frame,
            inputs=[input_first_frame],
            outputs=[drawing_board],
        )

        seg_acc_stroke.click(
            fn=sam_stroke_wrapper,
            inputs=[
                click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id
            ],
            outputs=[
                click_stack, input_first_frame, drawing_board, last_draw
            ]
        )

        input_video.change(
            fn=change_video,
            inputs=[input_video],
            outputs=[scale_slider, frame_per]
        )

    app.queue(concurrency_count=1)
    app.launch(debug=True, share=False)


if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    seg_track_app()
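
# The wrappers above (sam_stroke_wrapper, tracking_objects_wrapper) assume a simple queue
# protocol from the worker functions imported from process_wrappers: each worker puts exactly
# one (error, result) tuple on the queue, with error set to None on success. The sketch below
# is only an illustration of that assumed contract, not part of this app; example_worker_process
# and do_heavy_work are hypothetical names.
#
#     def example_worker_process(queue, *args):
#         try:
#             result = do_heavy_work(*args)  # placeholder for the actual GPU-bound computation
#             queue.put((None, result))      # success: no error, return the result
#         except Exception as e:
#             queue.put((str(e), None))      # failure: report the error message, no result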