# MedSAM2 / app.py
# import subprocess
# import re
# from typing import List, Tuple, Optional
# command = ["python", "setup.py", "build_ext", "--inplace"]
# result = subprocess.run(command, capture_output=True, text=True)
# print("Output:\n", result.stdout)
# print("Errors:\n", result.stderr)
# if result.returncode == 0:
#     print("Command executed successfully.")
# else:
#     print("Command failed with return code:", result.returncode)
import gc
import math
# import multiprocessing as mp
import torch.multiprocessing as mp
import os

from process_wrappers import clear_folder, draw_markers, sam_click_wrapper1, sam_stroke_process, tracking_objects_process

# Enable the cuDNN backend for PyTorch's scaled dot-product attention.
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"

import ffmpeg
import cv2
def clean():
    # Reset the UI: click_stack, first-frame previews, slider, outputs, and ann_obj_id.
    return ({}, {}, {}), None, None, 0, None, None, None, 0
def show_res_by_slider(frame_per, click_stack):
    image_path = '/tmp/output_frames'
    output_combined_dir = '/tmp/output_combined'
    combined_frames = sorted([os.path.join(output_combined_dir, img_name) for img_name in os.listdir(output_combined_dir)])
    if combined_frames:
        output_masked_frame_path = combined_frames
    else:
        original_frames = sorted([os.path.join(image_path, img_name) for img_name in os.listdir(image_path)])
        output_masked_frame_path = original_frames
    total_frames_num = len(output_masked_frame_path)
    if total_frames_num == 0:
        print("No output results found")
        # Three outputs are wired to this callback (frame image, drawing board, frame_num).
        return None, None, 0
    frame_num = math.floor(total_frames_num * frame_per / 100)
    if frame_per == 100:
        frame_num = frame_num - 1
    chosen_frame_path = output_masked_frame_path[frame_num]
    print(chosen_frame_path)
    chosen_frame_show = cv2.imread(chosen_frame_path)
    chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
    points_dict, labels_dict, masks_dict = click_stack
    if frame_num in points_dict and frame_num in labels_dict:
        chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
    return chosen_frame_show, chosen_frame_show, frame_num
def increment_ann_obj_id(ann_obj_id):
    ann_obj_id += 1
    return ann_obj_id

def drawing_board_get_input_first_frame(input_first_frame):
    return input_first_frame
def sam_stroke_wrapper(click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id):
    queue = mp.Queue()
    p = mp.Process(target=sam_stroke_process, args=(queue, click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id))
    p.start()
    error, result = queue.get()
    p.join()
    if error:
        raise Exception(f"Error in sam_stroke_process: {error}")
    return result

def tracking_objects_wrapper(click_stack, checkpoint, frame_num, input_video):
    queue = mp.Queue()
    p = mp.Process(target=tracking_objects_process, args=(queue, click_stack, checkpoint, frame_num, input_video))
    p.start()
    error, result = queue.get()
    p.join()
    if error:
        raise Exception(f"Error in tracking_objects_process: {error}")
    return result
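
# Both wrappers above run SAM2 inference in a spawned child process and pass
# results back through an mp.Queue. A likely motivation (not stated in the
# source) is that all GPU memory is reliably released when the child exits,
# which a long-lived Gradio worker process cannot otherwise guarantee.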
def seg_track_app():
    import gradio as gr

    def sam_click_wrapper(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
        return sam_click_wrapper1(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, [evt.index[0], evt.index[1]])

    def change_video(input_video):
        if input_video is None:
            return 0, 0
        cap = cv2.VideoCapture(input_video)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        scale_slider = gr.Slider.update(minimum=1.0, maximum=fps, step=1.0, value=fps)
        frame_per = gr.Slider.update(minimum=0.0, maximum=total_frames / fps, step=1.0 / fps, value=0.0)
        return scale_slider, frame_per

    def get_meta_from_video(input_video, scale_slider):
        output_dir = '/tmp/output_frames'
        output_masks_dir = '/tmp/output_masks'
        output_combined_dir = '/tmp/output_combined'
        clear_folder(output_dir)
        clear_folder(output_masks_dir)
        clear_folder(output_combined_dir)
        if input_video is None:
            return ({}, {}, {}), None, None, 0, None, None, None, 0
        cap = cv2.VideoCapture(input_video)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        frame_interval = max(1, int(fps // scale_slider))
        print(f"frame_interval: {frame_interval}")
        # Extract every frame_interval-th frame as a JPEG, numbered from 0.
        try:
            ffmpeg.input(input_video, hwaccel='cuda').output(
                os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
                vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
            ).run()
        except Exception:
            print("ffmpeg CUDA decoding failed; falling back to CPU decoding")
            ffmpeg.input(input_video).output(
                os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
                vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
            ).run()
        first_frame_path = os.path.join(output_dir, '0000000.jpg')
        first_frame = cv2.imread(first_frame_path)
        first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        frame_per = gr.Slider.update(minimum=0.0, maximum=total_frames / fps, step=frame_interval / fps, value=0.0)
        return ({}, {}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0
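
    # Preprocessing summary: frames are extracted to /tmp/output_frames at the
    # downsampled rate, while /tmp/output_masks and /tmp/output_combined are
    # cleared and presumably refilled later by the tracking step
    # (see tracking_objects_wrapper).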
    ##########################################################
    ######################  Front-end  #######################
    ##########################################################
    css = """
    #input_output_video video {
        max-height: 550px;
        max-width: 100%;
        height: auto;
    }
    """
    app = gr.Blocks(css=css)
    with app:
        gr.Markdown(
            '''
<div style="text-align:center; margin-bottom:20px;">
<span style="font-size:3em; font-weight:bold;">MedSAM2 for Video Segmentation 🔥</span>
</div>
<div style="text-align:center; margin-bottom:10px;">
<span style="font-size:1.5em; font-weight:bold;">MedSAM2-Segment Anything in Medical Images and Videos: Benchmark and Deployment</span>
</div>
<div style="text-align:center; margin-bottom:20px;">
<a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://arxiv.org/abs/2408.03322">
<img src="https://img.shields.io/badge/arXiv-2408.03322-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://github.com/bowang-lab/MedSAMSlicer/tree/SAM2">
<img src="https://img.shields.io/badge/3D-Slicer-Plugin" alt="3D Slicer Plugin" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://drive.google.com/drive/folders/1EXzRkxZmrXbahCFA8_ImFRM6wQDEpOSe?usp=sharing">
<img src="https://img.shields.io/badge/Video-Tutorial-green?style=plastic" alt="Video Tutorial" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2?tab=readme-ov-file#fine-tune-sam2-on-the-abdomen-ct-dataset">
<img src="https://img.shields.io/badge/Fine--tune-SAM2-blue" alt="Fine-tune SAM2" style="display:inline-block; margin-right:10px;">
</a>
</div>
<div style="text-align:left; margin-bottom:20px;">
This API supports box prompts (generated from a scribble) and point prompts for video segmentation with
<a href="https://ai.meta.com/sam2/" target="_blank">SAM2</a>. You are welcome to join our <a href="https://forms.gle/hk4Efp6uWnhjUHFP6" target="_blank">mailing list</a> to get updates or send feedback.
</div>
<div style="margin-bottom:20px;">
<ol style="list-style:none; padding-left:0;">
<li>1. Upload a video file</li>
<li>2. Select the model size and downsample frame rate, then click <b>Preprocess</b></li>
<li>3. Use <b>Stroke to Box Prompt</b> to draw a box on the first frame, or <b>Point Prompt</b> to click on the first frame.</li>
<li>&nbsp;&nbsp;&nbsp;Note: The bounding rectangle of the stroke should cover the segmentation target.</li>
<li>4. Click <b>Segment</b> to get the segmentation result</li>
<li>5. Click <b>Add New Object</b> to add a new object</li>
<li>6. Click <b>Start Tracking</b> to track the objects through the video</li>
<li>7. Click <b>Reset</b> to reset the app</li>
<li>8. Download the video with segmentation results</li>
</ol>
</div>
<div style="text-align:left; line-height:1.8;">
We designed this API and <a href="https://github.com/bowang-lab/MedSAMSlicer/tree/SAM2" target="_blank">3D Slicer Plugin</a> for medical image and video segmentation where the checkpoints are based on the original SAM2 models (<a href="https://github.com/facebookresearch/segment-anything-2" target="_blank">https://github.com/facebookresearch/segment-anything-2</a>). The image segmentation fine-tune code has been released on <a href="https://github.com/bowang-lab/MedSAM/tree/MedSAM2?tab=readme-ov-file#fine-tune-sam2-on-the-abdomen-ct-dataset" target="_blank">GitHub</a>. The video fine-tuning code is under active development and will be released as well.
</div>
<div style="text-align:left; line-height:1.8;">
If you find these tools useful, please consider citing the following papers:
</div>
<div style="text-align:left; line-height:1.8;">
Ravi, N., Gabeur, V., Hu, Y.T., Hu, R., Ryali, C., Ma, T., Khedr, H., Rädle, R., Rolland, C., Gustafson, L., Mintun, E., Pan, J., Alwala, K.V., Carion, N., Wu, C.Y., Girshick, R., Dollár, P., Feichtenhofer, C.: SAM 2: Segment Anything in Images and Videos. arXiv:2408.00714 (2024)
</div>
<div style="text-align:left; line-height:1.8;">
Ma, J., Kim, S., Li, F., Baharoon, M., Asakereh, R., Lyu, H., Wang, B.: Segment Anything in Medical Images and Videos: Benchmark and Deployment. arXiv preprint arXiv:2408.03322 (2024)
</div>
<div style="text-align:left; line-height:1.8;">
Other useful resources:
<a href="https://ai.meta.com/sam2" target="_blank">Official demo</a> from MetaAI,
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y" target="_blank">Video tutorial</a> from Piotr Skalski.
</div>
            '''
        )
        # Cross-callback state: click_stack holds (points, labels, masks) dicts
        # keyed by frame index; ann_obj_id distinguishes tracked objects.
        click_stack = gr.State(({}, {}, {}))
        frame_num = gr.State(value=0)
        ann_obj_id = gr.State(value=0)
        last_draw = gr.State(None)
        with gr.Row():
            with gr.Column(scale=0.5):
                with gr.Row():
                    tab_video_input = gr.Tab(label="Video input")
                    with tab_video_input:
                        input_video = gr.Video(label='Input video', elem_id="input_output_video")
                        with gr.Row():
                            checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
                            scale_slider = gr.Slider(
                                label="Downsample Frame Rate (fps)",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.25,
                                value=1.0,
                                interactive=True
                            )
                            preprocess_button = gr.Button(
                                value="Preprocess",
                                interactive=True,
                            )
                with gr.Row():
                    tab_stroke = gr.Tab(label="Stroke to Box Prompt")
                    with tab_stroke:
                        drawing_board = gr.Image(label='Drawing Board', tool="sketch", brush_radius=10, interactive=True)
                        with gr.Row():
                            seg_acc_stroke = gr.Button(value="Segment", interactive=True)
                    tab_click = gr.Tab(label="Point Prompt")
                    with tab_click:
                        input_first_frame = gr.Image(label='Segment result of first frame', interactive=True, height=550)
                        with gr.Row():
                            point_mode = gr.Radio(
                                choices=["Positive", "Negative"],
                                value="Positive",
                                label="Point Prompt",
                                interactive=True)
                with gr.Row():
                    with gr.Column():
                        frame_per = gr.Slider(
                            label="Time (seconds)",
                            minimum=0.0,
                            maximum=100.0,
                            step=0.01,
                            value=0.0,
                        )
                        new_object_button = gr.Button(
                            value="Add New Object",
                            interactive=True
                        )
                        track_for_video = gr.Button(
                            value="Start Tracking",
                            interactive=True,
                        )
                        reset_button = gr.Button(
                            value="Reset",
                            interactive=True,
                        )
            with gr.Column(scale=0.5):
                output_video = gr.Video(label='Visualize Results', elem_id="input_output_video")
                output_mp4 = gr.File(label="Predicted video")
                output_mask = gr.File(label="Predicted masks")
                with gr.Tab(label='Video examples'):
                    gr.Examples(
                        label="",
                        examples=[
                            "assets/12fps_Dancing_cells_trimmed.mp4",
                            "assets/clip_012251_fps5_07_25.mp4",
                            "assets/FLARE22_Tr_0004.mp4",
                            "assets/c_elegans_mov_cut_fps12.mp4",
                        ],
                        inputs=[input_video],
                    )
                    gr.Examples(
                        label="",
                        examples=[
                            "assets/12fps_volvox_microcystis_play_trimmed.mp4",
                            "assets/12fps_macrophages_phagocytosis.mp4",
                            "assets/12fps_worm_eats_organism_5.mp4",
                            "assets/12fps_worm_eats_organism_6.mp4",
                            "assets/12fps_02_cups.mp4",
                        ],
                        inputs=[input_video],
                    )
        gr.Markdown(
            '''
<div style="text-align:center; margin-top: 20px;">
The authors of this work highly appreciate Meta AI for making SAM2 publicly available to the community.
The interface was built on <a href="https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/tutorial/tutorial%20for%20WebUI-1.0-Version.md" target="_blank">SegTracker</a>, which is also an amazing tool for video segmentation tracking.
<a href="https://docs.google.com/document/d/1idDBV0faOjdjVs-iAHr0uSrw_9_ZzLGrUI2FEdK-lso/edit?usp=sharing" target="_blank">Data source</a>
</div>
            '''
        )
        ##########################################################
        ######################  Back-end  ########################
        ##########################################################
        # Listen to the preprocess button to extract frames and show the first frame with scaling.
        preprocess_button.click(
            fn=get_meta_from_video,
            inputs=[
                input_video,
                scale_slider,
            ],
            outputs=[
                click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
            ]
        )
        frame_per.release(
            fn=show_res_by_slider,
            inputs=[
                frame_per, click_stack
            ],
            outputs=[
                input_first_frame, drawing_board, frame_num
            ]
        )
        # Interactively modify the mask according to point clicks.
        input_first_frame.select(
            fn=sam_click_wrapper,
            inputs=[
                checkpoint, frame_num, point_mode, click_stack, ann_obj_id
            ],
            outputs=[
                input_first_frame, drawing_board, click_stack
            ]
        )
        # Track the annotated objects through the video.
        track_for_video.click(
            fn=tracking_objects_wrapper,
            inputs=[
                click_stack,
                checkpoint,
                frame_num,
                input_video,
            ],
            outputs=[
                input_first_frame,
                drawing_board,
                output_video,
                output_mp4,
                output_mask
            ]
        )
        reset_button.click(
            fn=clean,
            inputs=[],
            outputs=[
                click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
            ]
        )
        new_object_button.click(
            fn=increment_ann_obj_id,
            inputs=[
                ann_obj_id
            ],
            outputs=[
                ann_obj_id
            ]
        )
        # Mirror the current first-frame preview onto the drawing board when the stroke tab is opened.
        tab_stroke.select(
            fn=drawing_board_get_input_first_frame,
            inputs=[input_first_frame],
            outputs=[drawing_board],
        )
        seg_acc_stroke.click(
            fn=sam_stroke_wrapper,
            inputs=[
                click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id
            ],
            outputs=[
                click_stack, input_first_frame, drawing_board, last_draw
            ]
        )
        input_video.change(
            fn=change_video,
            inputs=[input_video],
            outputs=[scale_slider, frame_per]
        )

    # concurrency_count=1 (Gradio 3.x queue API) serializes requests so only one GPU job runs at a time.
    app.queue(concurrency_count=1)
    app.launch(debug=True, share=False)


if __name__ == "__main__":
    # CUDA cannot be reinitialized in forked children, so the 'spawn' start method is required.
    mp.set_start_method('spawn', force=True)
    seg_track_app()