import torch
import numpy as np
import gradio as gr
from scipy.ndimage import binary_fill_holes
from ultralytics import YOLOE
from ultralytics.utils.torch_utils import smart_inference_mode
from ultralytics.models.yolo.yoloe.predict_vp import YOLOEVPSegPredictor
from gradio_image_prompter import ImagePrompter
from huggingface_hub import hf_hub_download
import spaces


@spaces.GPU
def init_model(model_id, is_pf=False):
    # Download the requested YOLOE segmentation checkpoint (the "-pf" variant
    # for prompt-free mode) and move it to the GPU.
    filename = f"{model_id}-seg-pf.pt" if is_pf else f"{model_id}-seg.pt"
    path = hf_hub_download(repo_id="jameslahm/yoloe", filename=filename)
    model = YOLOE(path)
    model.eval()
    model.to("cuda")
    return model


@spaces.GPU
@smart_inference_mode()
def yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type):
    model = init_model(model_id)
    kwargs = {}
    if prompt_type == "Text":
        # Text prompts: use text embeddings of the class names as the vocabulary.
        texts = prompts["texts"]
        model.set_classes(texts, model.get_text_pe(texts))
    elif prompt_type == "Visual":
        kwargs = dict(prompts=prompts, predictor=YOLOEVPSegPredictor)
        if target_image is not None:
            # Inter-image mode: a first pass extracts visual prompt embeddings
            # (VPE) from the reference image; they are registered as a class,
            # and the actual prediction then runs on the target image.
            model.predict(source=image, imgsz=image_size, conf=conf_thresh, iou=iou_thresh, return_vpe=True, **kwargs)
            model.set_classes(["object0"], model.predictor.vpe)
            model.predictor = None  # unset VPPredictor
            image = target_image
            kwargs = {}
    elif prompt_type == "Prompt-free":
        # Prompt-free mode: transfer the vocabulary into the prompt-free
        # checkpoint and relax the fused head so it proposes many candidates.
        vocab = model.get_vocab(prompts["texts"])
        model = init_model(model_id, is_pf=True)
        model.set_vocab(vocab, names=prompts["texts"])
        model.model.model[-1].is_fused = True
        model.model.model[-1].conf = 0.001
        model.model.model[-1].max_det = 1000
    results = model.predict(source=image, imgsz=image_size, conf=conf_thresh, iou=iou_thresh, **kwargs)
    annotated_image = results[0].plot()
    return annotated_image[:, :, ::-1]  # BGR -> RGB for display


def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # Only one of these inputs is visible at a time, depending
                    # on the selected prompt tab and visual prompt type.
                    raw_image = gr.Image(type="pil", label="Image", visible=True, interactive=True)
                    box_image = ImagePrompter(type="pil", label="DrawBox", visible=False, interactive=True)
                    mask_image = gr.ImageEditor(type="pil", label="DrawMask", visible=False, interactive=True, layers=False, canvas_size=(640, 640))
                    target_image = gr.Image(type="pil", label="Target Image", visible=False, interactive=True)
                yoloe_infer = gr.Button(value="Detect & Segment Objects")
                prompt_type = gr.Textbox(value="Text", visible=False)  # hidden flag tracking the active tab
                with gr.Tab("Text") as text_tab:
                    texts = gr.Textbox(label="Input Texts", value='person,bus', placeholder='person,bus', visible=True, interactive=True)
                with gr.Tab("Visual") as visual_tab:
                    with gr.Row():
                        visual_prompt_type = gr.Dropdown(choices=["bboxes", "masks"], value="bboxes", label="Visual Type", interactive=True)
                        visual_usage_type = gr.Radio(choices=["Intra-Image", "Inter-Image"], value="Intra-Image", label="Intra/Inter Image", interactive=True)
                with gr.Tab("Prompt-Free") as prompt_free_tab:
                    gr.HTML(
                        """
                        <p style="text-align: center;">Prompt-Free Mode is On</p>
""", show_label=False) model_id = gr.Dropdown( label="Model", choices=[ "yoloe-v8s", "yoloe-v8m", "yoloe-v8l", "yoloe-11s", "yoloe-11m", "yoloe-11l", ], value="yoloe-v8l", ) image_size = gr.Slider( label="Image Size", minimum=320, maximum=1280, step=32, value=640, ) conf_thresh = gr.Slider( label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.25, ) iou_thresh = gr.Slider( label="IoU Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.70, ) with gr.Column(): output_image = gr.Image(type="numpy", label="Annotated Image", visible=True) def update_text_image_visibility(): return gr.update(value="Text"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) def update_visual_image_visiblity(visual_prompt_type, visual_usage_type): use_target = gr.update(visible=True) if visual_usage_type == "Inter-Image" else gr.update(visible=False) if visual_prompt_type == "bboxes": return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), use_target elif visual_prompt_type == "masks": return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), use_target def update_pf_image_visibility(): return gr.update(value="Prompt-free"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) text_tab.select( fn=update_text_image_visibility, inputs=None, outputs=[prompt_type, raw_image, box_image, mask_image, target_image] ) visual_tab.select( fn=update_visual_image_visiblity, inputs=[visual_prompt_type, visual_usage_type], outputs=[prompt_type, raw_image, box_image, mask_image, target_image] ) prompt_free_tab.select( fn=update_pf_image_visibility, inputs=None, outputs=[prompt_type, raw_image, box_image, mask_image, target_image] ) def update_visual_prompt_type(visual_prompt_type): if visual_prompt_type == "bboxes": return gr.update(visible=True), gr.update(visible=False) if visual_prompt_type == "masks": return gr.update(visible=False), gr.update(visible=True) return gr.update(visible=False), gr.update(visible=False) def update_visual_usage_type(visual_usage_type): if visual_usage_type == "Intra-Image": return gr.update(visible=False, value=None) if visual_usage_type == "Inter-Image": return gr.update(visible=True, value=None) return gr.update(visible=False, value=None) visual_prompt_type.change( fn=update_visual_prompt_type, inputs=[visual_prompt_type], outputs=[box_image, mask_image] ) visual_usage_type.change( fn=update_visual_usage_type, inputs=[visual_usage_type], outputs=[target_image] ) def run_inference(raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type): # add text/built-in prompts if prompt_type == "Text" or prompt_type == "Prompt-free": image = raw_image if prompt_type == "Prompt-free": with open('tools/ram_tag_list.txt', 'r') as f: texts = [x.strip() for x in f.readlines()] else: texts = [text.strip() for text in texts.split(',')] prompts = { "texts": texts } # add visual prompt elif prompt_type == "Visual": if visual_prompt_type == "bboxes": image, points = box_image["image"], box_image["points"] points = np.array(points) prompts = { "bboxes": np.array([p[[0, 1, 3, 4]] for p in points if p[2] == 2]), } elif visual_prompt_type == "masks": image, masks = mask_image["background"], mask_image["layers"][0] # image = image.convert("RGB") masks = np.array(masks.convert("L")) masks = 
                    masks = binary_fill_holes(masks).astype(np.uint8)
                    masks[masks > 0] = 1  # binarize: any painted pixel counts as foreground
                    prompts = {
                        "masks": masks[None],  # add a leading prompt dimension
                    }
            return yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type)

        yoloe_infer.click(
            fn=run_inference,
            inputs=[raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type],
            outputs=[output_image],
        )


gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style="text-align: center;">YOLOE: Real-Time Seeing Anything</h1>
""") gr.HTML( """

arXiv | github

""") gr.Markdown( """ We introduce **YOLOE(ye)**, a highly **efficient**, **unified**, and **open** object detection and segmentation model, like human eye, under different prompt mechanisms, like *texts*, *visual inputs*, and *prompt-free paradigm*. """ ) gr.Markdown( """ If desired objects are not identified, pleaset set a **smaller** confidence threshold, e.g., for visual prompts with handcrafted shape or cross-image prompts. """ ) gr.Markdown( """ Drawing **multiple** boxes or handcrafted shapes as visual prompt in an image is also supported. """ ) with gr.Row(): with gr.Column(): app() if __name__ == '__main__': gradio_app.launch(allowed_paths=["figures"])