# Spaces: Runtime error
from typing import Optional | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import supervision as sv | |
import torch | |
from PIL import Image | |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator | |
from sam2.build_sam import build_sam2 | |
# Introductory text rendered at the top of the demo page via gr.Markdown.
MARKDOWN = (
    "\n"
    "# Segment Anything Model 2 🔥\n"
    "Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable\n"
    "visual segmentation in both images and videos. The model extends its functionality to\n"
    "video by treating images as single-frame videos. Its design, a simple transformer\n"
    "architecture with streaming memory, enables real-time video processing.\n"
)
# Checkpoint weights and model config handed to build_sam2 below.
CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
CONFIG = "sam2_hiera_l.yaml"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Shared annotator used to draw the generated masks onto the image.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

# The model is built once at import time and reused by every request.
sam2_model = build_sam2(
    CONFIG,
    CHECKPOINT,
    device=DEVICE,
    apply_postprocessing=False,
)
def process(image_input) -> "Optional[Image.Image]":
    """Segment an uploaded image with SAM 2 and return it with masks drawn on.

    Args:
        image_input: The PIL image supplied by the input component, or
            ``None`` when the button is pressed before an image is uploaded.

    Returns:
        The annotated image, or ``None`` when no image was provided.
    """
    # Without this guard, submitting with no image fails on `.convert`.
    if image_input is None:
        return None

    # Work from a single RGB copy so the array given to the mask generator and
    # the scene given to the annotator describe the same pixels.
    rgb_image = image_input.convert("RGB")

    mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
    sam_result = mask_generator.generate(np.array(rgb_image))
    detections = sv.Detections.from_sam(sam_result=sam_result)
    return MASK_ANNOTATOR.annotate(scene=rgb_image, detections=detections)
# Page layout: intro text on top, then a row with the upload column (image +
# submit button) on the left and the result column on the right.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(label="Upload image", type="pil")
            submit_button_component = gr.Button("Submit", variant="primary")
        with gr.Column():
            image_output_component = gr.Image(label="Image Output", type="pil")

    # Clicking the button sends the uploaded image through `process` and
    # shows whatever it returns in the output image.
    submit_button_component.click(
        fn=process,
        inputs=[image_input_component],
        outputs=[image_output_component],
    )

# Start the app; show_error=True is passed so failures are shown in the UI.
demo.launch(
    debug=False,
    show_error=True,
    max_threads=1,
)