import gc
import math
# torch.multiprocessing is a drop-in replacement for the stdlib multiprocessing module with support for CUDA tensors
import torch.multiprocessing as mp
import os
from process_wrappers import clear_folder, draw_markers, sam_click_wrapper1, sam_stroke_process, tracking_objects_process
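# Enable the cuDNN backend for scaled dot-product attention in PyTorch (used by SAM2's attention layers); set before any model code runs.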
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
import ffmpeg
import cv2
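# Reset handler: returns fresh values for every stateful output (click stack, preview images, slider position, result files, and object id).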
def clean():
return ({}, {}, {}), None, None, 0, None, None, None, 0
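# Map the slider value (treated as a percentage of the extracted sequence) to a frame index and return that frame.
# Prefers the mask-overlaid frames in /tmp/output_combined once tracking has produced them, otherwise falls back to the raw extracted frames.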
def show_res_by_slider(frame_per, click_stack):
image_path = '/tmp/output_frames'
output_combined_dir = '/tmp/output_combined'
combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)])
if combined_frames:
output_masked_frame_path = combined_frames
else:
original_frames = sorted([os.path.join(image_path, img_name) for img_name in os.listdir(image_path)])
output_masked_frame_path = original_frames
total_frames_num = len(output_masked_frame_path)
if total_frames_num == 0:
print("No output results found")
return None, None, 0
else:
frame_num = math.floor(total_frames_num * frame_per / 100)
if frame_per == 100:
frame_num = frame_num - 1
chosen_frame_path = output_masked_frame_path[frame_num]
print(f"{chosen_frame_path}")
chosen_frame_show = cv2.imread(chosen_frame_path)
chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
points_dict, labels_dict, masks_dict = click_stack
if frame_num in points_dict and frame_num in labels_dict:
chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
return chosen_frame_show, chosen_frame_show, frame_num
def increment_ann_obj_id(ann_obj_id):
ann_obj_id += 1
return ann_obj_id
def drawing_board_get_input_first_frame(input_first_frame):
return input_first_frame
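# Run SAM2 box-prompt (stroke) segmentation in a separate process; the child puts (error, result) on the queue, and exiting the process releases the GPU memory it used.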
def sam_stroke_wrapper(click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id):
queue = mp.Queue()
p = mp.Process(target=sam_stroke_process, args=(queue, click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id))
p.start()
error, result = queue.get()
p.join()
if error:
raise Exception(f"Error in sam_stroke_process: {error}")
return result
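# Same pattern as sam_stroke_wrapper: run video tracking in a child process and collect (error, result) from the queue.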
def tracking_objects_wrapper(click_stack, checkpoint, frame_num, input_video):
queue = mp.Queue()
p = mp.Process(target=tracking_objects_process, args=(queue, click_stack, checkpoint, frame_num, input_video))
p.start()
error, result = queue.get()
p.join()
if error:
raise Exception(f"Error in tracking_objects_process: {error}")
return result
def seg_track_app():
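# Gradio is imported inside this function rather than at module level, which keeps the UI stack out of the child processes that re-import this module under the 'spawn' start method.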
import gradio as gr
def sam_click_wrapper(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
return sam_click_wrapper1(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, [evt.index[0], evt.index[1]])
def change_video(input_video):
import gradio as gr
if input_video is None:
return 0, 0
cap = cv2.VideoCapture(input_video)
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
scale_slider = gr.Slider.update(minimum=1.0,
maximum=fps,
step=1.0,
value=fps)
frame_per = gr.Slider.update(minimum=0.0,
maximum=total_frames / fps,
step=1.0 / fps,
value=0.0)
return scale_slider, frame_per
def get_meta_from_video(input_video, scale_slider):
import gradio as gr
output_dir = '/tmp/output_frames'
output_masks_dir = '/tmp/output_masks'
output_combined_dir = '/tmp/output_combined'
clear_folder(output_dir)
clear_folder(output_masks_dir)
clear_folder(output_combined_dir)
if input_video is None:
return ({}, {}, {}), None, None, 0, None, None, None, 0
cap = cv2.VideoCapture(input_video)
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
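# Keep one of every frame_interval source frames so the extracted sequence approximates the requested frame rate (scale_slider).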
frame_interval = max(1, int(fps // scale_slider))
print(f"frame_interval: {frame_interval}")
try:
ffmpeg.input(input_video, hwaccel='cuda').output(
os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
).run()
except Exception:
print("ffmpeg CUDA decoding failed, falling back to CPU decoding")
ffmpeg.input(input_video).output(
os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
).run()
first_frame_path = os.path.join(output_dir, '0000000.jpg')
first_frame = cv2.imread(first_frame_path)
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
frame_per = gr.Slider.update(minimum=0.0,
maximum=total_frames / fps,
step=frame_interval / fps,
value=0.0)
return ({}, {}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0
##########################################################
###################### Front-end ########################
##########################################################
css = """
#input_output_video video {
max-height: 550px;
max-width: 100%;
height: auto;
}
"""
app = gr.Blocks(css=css)
with app:
gr.Markdown(
'''
MedSAM2 for Video Segmentation 🔥
MedSAM2-Segment Anything in Medical Images and Videos: Benchmark and Deployment
This API supports box prompts (generated from a scribble) and point prompts for video segmentation with SAM2.
You are welcome to join our mailing list to get updates or send feedback.
- 1. Upload a video file
- 2. Select the model size and the downsampling frame rate, then run Preprocess
- 3. Use Stroke to Box Prompt to draw a box on the first frame, or Point Prompt to click on the first frame.
- Note: The bounding rectangle of the stroke should cover the segmentation target.
- 4. Click Segment to get the segmentation result
- 5. Click Add New Object to add a new object
- 6. Click Start Tracking to track objects in the video
- 7. Click Reset to reset the app
- 8. Download the video with segmentation results
If you find these tools useful, please consider citing the following papers:
Ravi, N., Gabeur, V., Hu, Y.T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädle, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K.V., Carion, N., Wu, C.Y., Girshick, R., Dollár, P., Feichtenhofer, C.: SAM 2: Segment Anything in Images and Videos. arXiv:2408.00714 (2024)
Ma, J., Kim, S., Li, F., Baharoon, M., Asakereh, R., Lyu, H., Wang, B.: Segment Anything in Medical Images and Videos: Benchmark and Deployment. arXiv:2408.03322 (2024)
'''
)
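# Per-session state: click_stack holds (points, labels, masks) dicts, with points and labels keyed by frame index; ann_obj_id is the id of the object currently being annotated.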
click_stack = gr.State(({}, {}, {}))
frame_num = gr.State(value=0)
ann_obj_id = gr.State(value=0)
last_draw = gr.State(None)
with gr.Row():
with gr.Column(scale=0.5):
with gr.Row():
tab_video_input = gr.Tab(label="Video input")
with tab_video_input:
input_video = gr.Video(label='Input video', elem_id="input_output_video")
with gr.Row():
checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
scale_slider = gr.Slider(
label="Downsampe Frame Rate (fps)",
minimum=0.0,
maximum=1.0,
step=0.25,
value=1.0,
interactive=True
)
preprocess_button = gr.Button(
value="Preprocess",
interactive=True,
)
with gr.Row():
tab_stroke = gr.Tab(label="Stroke to Box Prompt")
with tab_stroke:
drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
with gr.Row():
seg_acc_stroke = gr.Button(value="Segment", interactive=True)
tab_click = gr.Tab(label="Point Prompt")
with tab_click:
input_first_frame = gr.Image(label='Segment result of first frame',interactive=True, height=550)
with gr.Row():
point_mode = gr.Radio(
choices=["Positive", "Negative"],
value="Positive",
label="Point Prompt",
interactive=True)
with gr.Row():
with gr.Column():
frame_per = gr.Slider(
label = "Time (seconds)",
minimum= 0.0,
maximum= 100.0,
step=0.01,
value=0.0,
)
new_object_button = gr.Button(
value="Add New Object",
interactive=True
)
track_for_video = gr.Button(
value="Start Tracking",
interactive=True,
)
reset_button = gr.Button(
value="Reset",
interactive=True,
)
with gr.Column(scale=0.5):
output_video = gr.Video(label='Visualize Results', elem_id="input_output_video")
output_mp4 = gr.File(label="Predicted video")
output_mask = gr.File(label="Predicted masks")
with gr.Tab(label='Video examples'):
gr.Examples(
label="",
examples=[
"assets/12fps_Dancing_cells_trimmed.mp4",
"assets/clip_012251_fps5_07_25.mp4",
"assets/FLARE22_Tr_0004.mp4",
"assets/c_elegans_mov_cut_fps12.mp4",
],
inputs=[input_video],
)
gr.Examples(
label="",
examples=[
"assets/12fps_volvox_microcystis_play_trimmed.mp4",
"assets/12fps_macrophages_phagocytosis.mp4",
"assets/12fps_worm_eats_organism_5.mp4",
"assets/12fps_worm_eats_organism_6.mp4",
"assets/12fps_02_cups.mp4",
],
inputs=[input_video],
)
gr.Markdown(
'''
The authors of this work highly appreciate Meta AI for making SAM2 publicly available to the community.
The interface was built on
SegTracker, which is also an amazing tool for video segmentation tracking.
Data source
'''
)
##########################################################
###################### back-end #########################
##########################################################
# listen to the preprocess button click to get the first frame of video with scaling
preprocess_button.click(
fn=get_meta_from_video,
inputs=[
input_video,
scale_slider,
],
outputs=[
click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
]
)
frame_per.release(
fn=show_res_by_slider,
inputs=[
frame_per, click_stack
],
outputs=[
input_first_frame, drawing_board, frame_num
]
)
# Interactively modify the mask acc click
input_first_frame.select(
fn=sam_click_wrapper,
inputs=[
checkpoint, frame_num, point_mode, click_stack, ann_obj_id
],
outputs=[
input_first_frame, drawing_board, click_stack
]
)
# Track object in video
track_for_video.click(
fn=tracking_objects_wrapper,
inputs=[
click_stack,
checkpoint,
frame_num,
input_video,
],
outputs=[
input_first_frame,
drawing_board,
output_video,
output_mp4,
output_mask
]
)
reset_button.click(
fn=clean,
inputs=[],
outputs=[
click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
]
)
new_object_button.click(
fn=increment_ann_obj_id,
inputs=[
ann_obj_id
],
outputs=[
ann_obj_id
]
)
tab_stroke.select(
fn=drawing_board_get_input_first_frame,
inputs=[input_first_frame,],
outputs=[drawing_board,],
)
seg_acc_stroke.click(
fn=sam_stroke_wrapper,
inputs=[
click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id
],
outputs=[
click_stack, input_first_frame, drawing_board, last_draw
]
)
input_video.change(
fn=change_video,
inputs=[input_video],
outputs=[scale_slider, frame_per]
)
app.queue(concurrency_count=1)
app.launch(debug=True, share=False)
if __name__ == "__main__":
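# 'spawn' gives each worker a clean CUDA context; 'fork' is not safe once CUDA has been initialized in the parent process.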
mp.set_start_method('spawn', force=True)
seg_track_app()