|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gc |
|
import math |
|
|
|
import torch.multiprocessing as mp |
|
import os |
|
from process_wrappers import clear_folder, draw_markers, sam_click_wrapper1, sam_stroke_process, tracking_objects_process |
|
# Opt in to the cuDNN scaled-dot-product-attention backend via environment variable.
# NOTE(review): this runs after torch.multiprocessing and process_wrappers are
# imported above; confirm PyTorch reads this variable late enough for it to take effect.
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
|
import ffmpeg |
|
import cv2 |
|
|
|
|
|
def clean():
    """Return the blank UI state used by the Reset button.

    The tuple lines up, in order, with the Reset outputs: the click stack
    (points, labels and masks dicts), the first-frame image, the drawing
    board, the time slider, the output video, the predicted-video file, the
    predicted-mask file and the annotation object id.
    """
    empty_click_stack = ({}, {}, {})
    no_image = None
    return (empty_click_stack, no_image, no_image, 0, None, None, None, 0)
|
|
|
def show_res_by_slider(frame_per, click_stack):
    """Load the frame selected by the slider and draw the saved prompts on it.

    Frames are taken from ``/tmp/output_combined`` when that folder holds any
    files, otherwise from the extracted frames in ``/tmp/output_frames``. A
    folder that does not exist is treated as empty.

    Args:
        frame_per: Slider position interpreted as a percentage (0-100) of the
            available frames. Out-of-range values are clamped to the first or
            last frame. NOTE(review): the ``frame_per`` slider built in
            ``seg_track_app`` is labelled in seconds, so callers need to
            convert to a percentage before calling this.
        click_stack: ``(points_dict, labels_dict, masks_dict)``. The first two
            are looked up by frame index to draw markers.

    Returns:
        ``(image, image, frame_index)``. ``image`` is the chosen frame
        converted from BGR to RGB, with any markers drawn. Returns
        ``(None, None, 0)`` when no frames are available.
    """
    image_path = '/tmp/output_frames'
    output_combined_dir = '/tmp/output_combined'

    def _list_frames(folder):
        # Sorted full paths of every file in ``folder``; [] if it is missing.
        if not os.path.isdir(folder):
            return []
        return sorted(os.path.join(folder, img_name) for img_name in os.listdir(folder))

    # Prefer the combined (overlay) frames; fall back to the raw frames.
    output_masked_frame_path = _list_frames(output_combined_dir) or _list_frames(image_path)

    total_frames_num = len(output_masked_frame_path)
    if total_frames_num == 0:
        print("No output results found")
        # One value per output wired to this callback (two images + frame index).
        return None, None, 0

    # Clamping replaces the old ``frame_per == 100`` special case (100% -> last
    # frame). It also stops values above 100 from raising IndexError and
    # negative values from silently wrapping to the end of the list.
    frame_num = math.floor(total_frames_num * frame_per / 100)
    frame_num = min(max(frame_num, 0), total_frames_num - 1)

    chosen_frame_path = output_masked_frame_path[frame_num]
    print(f"{chosen_frame_path}")
    chosen_frame_show = cv2.imread(chosen_frame_path)
    chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)

    points_dict, labels_dict, _masks_dict = click_stack
    if frame_num in points_dict and frame_num in labels_dict:
        chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
    return chosen_frame_show, chosen_frame_show, frame_num
|
|
|
def increment_ann_obj_id(ann_obj_id):
    """Return the id for the next annotated object (used by "Add New Object")."""
    return ann_obj_id + 1
|
|
|
def drawing_board_get_input_first_frame(input_first_frame):
    """Pass the Point Prompt image through unchanged so it can seed the drawing board."""
    frame = input_first_frame
    return frame
|
|
|
def sam_stroke_wrapper(click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id):
    """Run ``sam_stroke_process`` in a separate process and return its result.

    The child receives a queue as its first argument and is expected to put a
    single ``(error, result)`` pair on it.

    Returns:
        ``result`` from the child, unchanged.

    Raises:
        Exception: if the child reports a truthy ``error``.
        RuntimeError: if the child exits without putting anything on the
            queue (e.g. it crashed). Previously this case blocked forever.
    """
    from queue import Empty  # raised by multiprocessing queues on get() timeout

    result_queue = mp.Queue()
    worker = mp.Process(
        target=sam_stroke_process,
        args=(result_queue, click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id),
    )
    worker.start()

    # Read before join(): per the multiprocessing docs, a child that is still
    # flushing a large result into the queue cannot exit until it is drained.
    while True:
        try:
            error, result = result_queue.get(timeout=1.0)
            break
        except Empty:
            # A child that exited normally has already flushed what it put, so
            # a dead child with an empty queue will never answer.
            if not worker.is_alive() and result_queue.empty():
                worker.join()
                raise RuntimeError(
                    f"sam_stroke_process exited with code {worker.exitcode} without returning a result"
                ) from None
    worker.join()

    if error:
        raise Exception(f"Error in sam_stroke_process: {error}")
    return result
|
|
|
def tracking_objects_wrapper(click_stack, checkpoint, frame_num, input_video):
    """Run ``tracking_objects_process`` in a separate process and return its result.

    The child receives a queue as its first argument and is expected to put a
    single ``(error, result)`` pair on it.

    Returns:
        ``result`` from the child, unchanged.

    Raises:
        Exception: if the child reports a truthy ``error``. The message now
            names ``tracking_objects_process``; it used to say
            ``sam_stroke_process`` (copy-paste error).
        RuntimeError: if the child exits without putting anything on the
            queue (e.g. it crashed). Previously this case blocked forever.
    """
    from queue import Empty  # raised by multiprocessing queues on get() timeout

    result_queue = mp.Queue()
    worker = mp.Process(
        target=tracking_objects_process,
        args=(result_queue, click_stack, checkpoint, frame_num, input_video),
    )
    worker.start()

    # Read before join(): per the multiprocessing docs, a child that is still
    # flushing a large result into the queue cannot exit until it is drained.
    while True:
        try:
            error, result = result_queue.get(timeout=1.0)
            break
        except Empty:
            # A child that exited normally has already flushed what it put, so
            # a dead child with an empty queue will never answer.
            if not worker.is_alive() and result_queue.empty():
                worker.join()
                raise RuntimeError(
                    f"tracking_objects_process exited with code {worker.exitcode} without returning a result"
                ) from None
    worker.join()

    if error:
        raise Exception(f"Error in tracking_objects_process: {error}")
    return result
|
|
|
def seg_track_app():
    """Build the MedSAM2 video-segmentation Gradio UI and launch it.

    Gradio is imported lazily here. The nested callbacks handle video probing
    and frame extraction. Segmentation and tracking are delegated to the
    module-level wrappers. This call blocks inside ``app.launch``.
    """
    import gradio as gr

    def _probe_video(input_video):
        """Return ``(fps, total_frames)`` for ``input_video`` as reported by OpenCV."""
        cap = cv2.VideoCapture(input_video)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        return fps, total_frames

    def sam_click_wrapper(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
        # Forward the clicked pixel coordinates carried by the select event.
        return sam_click_wrapper1(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, [evt.index[0], evt.index[1]])

    def change_video(input_video):
        """Reset the frame-rate and time sliders to match a newly loaded video."""
        if input_video is None:
            return 0, 0
        fps, total_frames = _probe_video(input_video)
        if fps <= 0:
            # OpenCV may report 0 when it cannot read the frame rate; leave the
            # sliders untouched instead of dividing by zero below.
            print(f"Could not read the frame rate of {input_video}")
            return gr.Slider.update(), gr.Slider.update()
        scale_slider = gr.Slider.update(minimum=1.0, maximum=fps, step=1.0, value=fps)
        frame_per = gr.Slider.update(minimum=0.0, maximum=total_frames / fps, step=1.0 / fps, value=0.0)
        return scale_slider, frame_per

    def get_meta_from_video(input_video, scale_slider):
        """Extract frames from ``input_video`` at roughly ``scale_slider`` fps.

        Clears the frame, mask and combined folders. Writes every
        ``frame_interval``-th frame to ``/tmp/output_frames`` as ``%07d.jpg``,
        numbered from 0.

        Returns the Preprocess outputs, in order: an empty click stack, the
        first frame (RGB) for both image widgets, the time-slider update,
        cleared outputs, and object id 0.
        """
        output_dir = '/tmp/output_frames'
        output_masks_dir = '/tmp/output_masks'
        # Was '/tmp/`output_combined`' (stray backticks). Stale combined frames
        # from a previous video were therefore never cleared, and
        # show_res_by_slider kept preferring them.
        output_combined_dir = '/tmp/output_combined'
        clear_folder(output_dir)
        clear_folder(output_masks_dir)
        clear_folder(output_combined_dir)

        if input_video is None:
            return ({}, {}, {}), None, None, 0, None, None, None, 0

        fps, total_frames = _probe_video(input_video)
        if fps <= 0:
            raise RuntimeError(f"Could not read the frame rate of {input_video}")

        # Keep one frame out of every frame_interval. A non-positive target rate
        # keeps every frame; this can happen because the slider's initial
        # minimum is 0.0 until change_video raises it.
        frame_interval = max(1, int(fps // scale_slider)) if scale_slider > 0 else 1
        print(f"frame_interval: {frame_interval}")

        frame_pattern = os.path.join(output_dir, '%07d.jpg')
        select_filter = rf'select=not(mod(n\,{frame_interval}))'
        try:
            ffmpeg.input(input_video, hwaccel='cuda').output(
                frame_pattern, q=2, start_number=0, vf=select_filter, vsync='vfr'
            ).run()
        except ffmpeg.Error as err:
            # Fall back to CPU decoding only when ffmpeg itself failed. The old
            # bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
            print(f"ffmpeg cuda err: {err}")
            ffmpeg.input(input_video).output(
                frame_pattern, q=2, start_number=0, vf=select_filter, vsync='vfr'
            ).run()

        first_frame_path = os.path.join(output_dir, '0000000.jpg')
        first_frame = cv2.imread(first_frame_path)
        if first_frame is None:
            raise RuntimeError(f"No frames were extracted to {first_frame_path}")
        first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

        frame_per = gr.Slider.update(minimum=0.0, maximum=total_frames / fps, step=frame_interval / fps, value=0.0)
        return ({}, {}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0

    def show_res_by_time(time_sec, click_stack, input_video):
        """Adapt the "Time (seconds)" slider to show_res_by_slider.

        The ``frame_per`` slider is configured in seconds
        (maximum = total_frames / fps), but show_res_by_slider expects a
        percentage of the available frames. Passing seconds straight through
        only reached the first ``duration``% of the video, and videos longer
        than 100 s overran the frame list.
        """
        duration = 0.0
        if input_video is not None:
            fps, total_frames = _probe_video(input_video)
            if fps > 0:
                duration = total_frames / fps
        percent = 100.0 * time_sec / duration if duration > 0 else 0.0
        # Clamp so float round-off at the slider's upper end maps exactly to 100%.
        percent = min(100.0, max(0.0, percent))
        return show_res_by_slider(percent, click_stack)

    css = """
#input_output_video video {
max-height: 550px;
max-width: 100%;
height: auto;
}
"""

    app = gr.Blocks(css=css)

    with app:
        gr.Markdown(
            '''
<div style="text-align:center; margin-bottom:20px;">
<span style="font-size:3em; font-weight:bold;">MedSAM2 for Video Segmentation 🔥</span>
</div>
<div style="text-align:center; margin-bottom:10px;">
<span style="font-size:1.5em; font-weight:bold;">MedSAM2-Segment Anything in Medical Images and Videos: Benchmark and Deployment</span>
</div>
<div style="text-align:center; margin-bottom:20px;">
<a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://arxiv.org/abs/2408.03322">
<img src="https://img.shields.io/badge/arXiv-2408.03322-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://github.com/bowang-lab/MedSAMSlicer/tree/SAM2">
<img src="https://img.shields.io/badge/3D-Slicer-Plugin" alt="3D Slicer Plugin" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://drive.google.com/drive/folders/1EXzRkxZmrXbahCFA8_ImFRM6wQDEpOSe?usp=sharing">
<img src="https://img.shields.io/badge/Video-Tutorial-green?style=plastic" alt="Video Tutorial" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2?tab=readme-ov-file#fine-tune-sam2-on-the-abdomen-ct-dataset">
<img src="https://img.shields.io/badge/Fine--tune-SAM2-blue" alt="Fine-tune SAM2" style="display:inline-block; margin-right:10px;">
</a>
</div>
<div style="text-align:left; margin-bottom:20px;">
This API supports using box (generated by scribble) and point prompts for video segmentation with
<a href="https://ai.meta.com/sam2/" target="_blank">SAM2</a>. Welcome to join our <a href="https://forms.gle/hk4Efp6uWnhjUHFP6" target="_blank">mailing list</a> to get updates or send feedback.
</div>
<div style="margin-bottom:20px;">
<ol style="list-style:none; padding-left:0;">
<li>1. Upload video file</li>
<li>2. Select model size and downsample frame rate and run <b>Preprocess</b></li>
<li>3. Use <b>Stroke to Box Prompt</b> to draw box on the first frame or <b>Point Prompt</b> to click on the first frame.</li>
<li> Note: The bounding rectangle of the stroke should be able to cover the segmentation target.</li>
<li>4. Click <b>Segment</b> to get the segmentation result</li>
<li>5. Click <b>Add New Object</b> to add new object</li>
<li>6. Click <b>Start Tracking</b> to track objects in the video</li>
<li>7. Click <b>Reset</b> to reset the app</li>
<li>8. Download the video with segmentation results</li>
</ol>
</div>
<div style="text-align:left; line-height:1.8;">
We designed this API and <a href="https://github.com/bowang-lab/MedSAMSlicer/tree/SAM2" target="_blank">3D Slicer Plugin</a> for medical image and video segmentation where the checkpoints are based on the original SAM2 models (<a href="https://github.com/facebookresearch/segment-anything-2" target="_blank">https://github.com/facebookresearch/segment-anything-2</a>). The image segmentation fine-tune code has been released on <a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2?tab=readme-ov-file#fine-tune-sam2-on-the-abdomen-ct-dataset" target="_blank">GitHub</a>. The video fine-tuning code is under active development and will be released as well.
</div>
<div style="text-align:left; line-height:1.8;">
If you find these tools useful, please consider citing the following papers:
</div>
<div style="text-align:left; line-height:1.8;">
Ravi, N., Gabeur, V., Hu, Y.T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädle, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K.V., Carion, N., Wu, C.Y., Girshick, R., Dollár, P., Feichtenhofer, C.: SAM 2: Segment Anything in Images and Videos. arXiv:2408.00714 (2024)
</div>
<div style="text-align:left; line-height:1.8;">
Ma, J., Kim, S., Li, F., Baharoon, M., Asakereh, R., Lyu, H., Wang, B.: Segment Anything in Medical Images and Videos: Benchmark and Deployment. arXiv preprint arXiv:2408.03322 (2024)
</div>
<div style="text-align:left; line-height:1.8;">
Other useful resources:
<a href="https://ai.meta.com/sam2" target="_blank">Official demo</a> from MetaAI,
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y" target="_blank">Video tutorial</a> from Piotr Skalski.
</div>
'''
        )

        # Per-session state shared between the callbacks.
        click_stack = gr.State(({}, {}, {}))
        frame_num = gr.State(value=0)
        ann_obj_id = gr.State(value=0)
        last_draw = gr.State(None)

        with gr.Row():
            with gr.Column(scale=0.5):
                with gr.Row():
                    tab_video_input = gr.Tab(label="Video input")
                    with tab_video_input:
                        input_video = gr.Video(label='Input video', elem_id="input_output_video")
                        with gr.Row():
                            checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
                            scale_slider = gr.Slider(
                                label="Downsample Frame Rate (fps)",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.25,
                                value=1.0,
                                interactive=True,
                            )
                            preprocess_button = gr.Button(
                                value="Preprocess",
                                interactive=True,
                            )

                with gr.Row():
                    tab_stroke = gr.Tab(label="Stroke to Box Prompt")
                    with tab_stroke:
                        drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
                        with gr.Row():
                            seg_acc_stroke = gr.Button(value="Segment", interactive=True)

                    tab_click = gr.Tab(label="Point Prompt")
                    with tab_click:
                        input_first_frame = gr.Image(label='Segment result of first frame', interactive=True, height=550)
                        with gr.Row():
                            point_mode = gr.Radio(
                                choices=["Positive", "Negative"],
                                value="Positive",
                                label="Point Prompt",
                                interactive=True,
                            )

                with gr.Row():
                    with gr.Column():
                        frame_per = gr.Slider(
                            label="Time (seconds)",
                            minimum=0.0,
                            maximum=100.0,
                            step=0.01,
                            value=0.0,
                        )
                        new_object_button = gr.Button(
                            value="Add New Object",
                            interactive=True,
                        )
                        track_for_video = gr.Button(
                            value="Start Tracking",
                            interactive=True,
                        )
                        reset_button = gr.Button(
                            value="Reset",
                            interactive=True,
                        )

            with gr.Column(scale=0.5):
                output_video = gr.Video(label='Visualize Results', elem_id="input_output_video")
                output_mp4 = gr.File(label="Predicted video")
                output_mask = gr.File(label="Predicted masks")

        with gr.Tab(label='Video examples'):
            gr.Examples(
                label="",
                examples=[
                    "assets/12fps_Dancing_cells_trimmed.mp4",
                    "assets/clip_012251_fps5_07_25.mp4",
                    "assets/FLARE22_Tr_0004.mp4",
                    "assets/c_elegans_mov_cut_fps12.mp4",
                ],
                inputs=[input_video],
            )
            gr.Examples(
                label="",
                examples=[
                    "assets/12fps_volvox_microcystis_play_trimmed.mp4",
                    "assets/12fps_macrophages_phagocytosis.mp4",
                    "assets/12fps_worm_eats_organism_5.mp4",
                    "assets/12fps_worm_eats_organism_6.mp4",
                    "assets/12fps_02_cups.mp4",
                ],
                inputs=[input_video],
            )

        gr.Markdown(
            '''
<div style="text-align:center; margin-top: 20px;">
The authors of this work highly appreciate Meta AI for making SAM2 publicly available to the community.
The interface was built on <a href="https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/tutorial/tutorial%20for%20WebUI-1.0-Version.md" target="_blank">SegTracker</a>, which is also an amazing tool for video segmentation tracking.
<a href="https://docs.google.com/document/d/1idDBV0faOjdjVs-iAHr0uSrw_9_ZzLGrUI2FEdK-lso/edit?usp=sharing" target="_blank">Data source</a>
</div>
'''
        )

        preprocess_button.click(
            fn=get_meta_from_video,
            inputs=[input_video, scale_slider],
            outputs=[click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id],
        )

        # The slider is in seconds; show_res_by_time converts it to the
        # percentage show_res_by_slider expects.
        frame_per.release(
            fn=show_res_by_time,
            inputs=[frame_per, click_stack, input_video],
            outputs=[input_first_frame, drawing_board, frame_num],
        )

        input_first_frame.select(
            fn=sam_click_wrapper,
            inputs=[checkpoint, frame_num, point_mode, click_stack, ann_obj_id],
            outputs=[input_first_frame, drawing_board, click_stack],
        )

        track_for_video.click(
            fn=tracking_objects_wrapper,
            inputs=[click_stack, checkpoint, frame_num, input_video],
            outputs=[input_first_frame, drawing_board, output_video, output_mp4, output_mask],
        )

        reset_button.click(
            fn=clean,
            inputs=[],
            outputs=[click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id],
        )

        new_object_button.click(
            fn=increment_ann_obj_id,
            inputs=[ann_obj_id],
            outputs=[ann_obj_id],
        )

        tab_stroke.select(
            fn=drawing_board_get_input_first_frame,
            inputs=[input_first_frame],
            outputs=[drawing_board],
        )

        seg_acc_stroke.click(
            fn=sam_stroke_wrapper,
            inputs=[click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id],
            outputs=[click_stack, input_first_frame, drawing_board, last_draw],
        )

        input_video.change(
            fn=change_video,
            inputs=[input_video],
            outputs=[scale_slider, frame_per],
        )

    app.queue(concurrency_count=1)
    app.launch(debug=True, share=False)
|
|
|
if __name__ == "__main__":
    # The wrappers above launch their workers through torch.multiprocessing;
    # use the "spawn" start method for them, overriding any method set earlier.
    mp.set_start_method(method="spawn", force=True)
    seg_track_app()
|
|