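"""Gradio demo for YOLOE: open-vocabulary detection and segmentation with text,
visual, or prompt-free prompts, visualized with the supervision library.

Run locally (assuming the jameslahm/yoloe Hub checkpoints and the figures/
assets are reachable):

    python app.py
"""
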
import numpy as np
import gradio as gr
import supervision as sv
from scipy.ndimage import binary_fill_holes
from ultralytics import YOLOE
from ultralytics.utils.torch_utils import smart_inference_mode
from ultralytics.models.yolo.yoloe.predict_vp import YOLOEVPSegPredictor
from gradio_image_prompter import ImagePrompter
from huggingface_hub import hf_hub_download
import spaces
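
# Checkpoints are fetched from the Hugging Face Hub; the "-pf" files are the
# prompt-free variants of the segmentation models.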
@spaces.GPU
def init_model(model_id, is_pf=False):
    filename = f"{model_id}-seg.pt" if not is_pf else f"{model_id}-seg-pf.pt"
    path = hf_hub_download(repo_id="jameslahm/yoloe", filename=filename)
    model = YOLOE(path)
    model.eval()
    model.to("cuda")
    return model

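# One inference pass: `prompts` carries texts, boxes, or masks depending on
# `prompt_type`; `target_image` is only used for cross-image visual prompting.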
@spaces.GPU
@smart_inference_mode()
def yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type):
    model = init_model(model_id)
    kwargs = {}
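    # Text prompts: encode the class names into text prompt embeddings.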
    if prompt_type == "Text":
        texts = prompts["texts"]
        model.set_classes(texts, model.get_text_pe(texts))
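    # Visual prompts (boxes or masks) are routed through the visual-prompt predictor.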
    elif prompt_type == "Visual":
        kwargs = dict(
            prompts=prompts,
            predictor=YOLOEVPSegPredictor
        )
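        # Cross-image prompting: extract the visual prompt embedding (VPE) from
        # the prompted image, then reuse it as a class embedding on the target image.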
        if target_image is not None:
            model.predict(source=image, imgsz=image_size, conf=conf_thresh, iou=iou_thresh, return_vpe=True, **kwargs)
            model.set_classes(["object0"], model.predictor.vpe)
            model.predictor = None  # unset VPPredictor
            image = target_image
            kwargs = {}
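    # Prompt-free: build the vocabulary with the text-prompt model, then
    # transfer it to the prompt-free checkpoint.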
    elif prompt_type == "Prompt-free":
        vocab = model.get_vocab(prompts["texts"])
        model = init_model(model_id, is_pf=True)
        model.set_vocab(vocab, names=prompts["texts"])
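        # Mark the detection head as fused and relax its internal conf/max_det
        # limits to accommodate the large built-in vocabulary.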
        model.model.model[-1].is_fused = True
        model.model.model[-1].conf = 0.001
        model.model.model[-1].max_det = 1000
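
    # Run prediction and render masks, boxes, and confidence labels with supervision.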
    results = model.predict(source=image, imgsz=image_size, conf=conf_thresh, iou=iou_thresh, **kwargs)
    detections = sv.Detections.from_ultralytics(results[0])

    resolution_wh = image.size
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh)
    labels = [
        f"{class_name} {confidence:.2f}"
        for class_name, confidence
        in zip(detections['class_name'], detections.confidence)
    ]

    annotated_image = image.copy()
    annotated_image = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX, opacity=0.4).annotate(
        scene=annotated_image, detections=detections)
    annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
        scene=annotated_image, detections=detections)
    annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate(
        scene=annotated_image, detections=detections, labels=labels)
    return annotated_image

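# Build the interactive UI: input widgets, prompt tabs, model settings, event
# wiring, and example galleries.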
def app():
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    raw_image = gr.Image(type="pil", label="Image", visible=True, interactive=True)
                    box_image = ImagePrompter(type="pil", label="DrawBox", visible=False, interactive=True)
                    mask_image = gr.ImageEditor(type="pil", label="DrawMask", visible=False, interactive=True, layers=False, canvas_size=(640, 640))
                    target_image = gr.Image(type="pil", label="Target Image", visible=False, interactive=True)
                yoloe_infer = gr.Button(value="Detect & Segment Objects")
                prompt_type = gr.Textbox(value="Text", visible=False)
                with gr.Tab("Text") as text_tab:
                    texts = gr.Textbox(label="Input Texts", value='person,bus', placeholder='person,bus', visible=True, interactive=True)
                with gr.Tab("Visual") as visual_tab:
                    with gr.Row():
                        visual_prompt_type = gr.Dropdown(choices=["bboxes", "masks"], value="bboxes", label="Visual Type", interactive=True)
                        visual_usage_type = gr.Radio(choices=["Intra-Image", "Cross-Image"], value="Intra-Image", label="Intra/Cross Image", interactive=True)
                with gr.Tab("Prompt-Free") as prompt_free_tab:
                    gr.HTML(
                        """
                        <p style='text-align: center'>
                            <b>Prompt-Free Mode is On</b>
                        </p>
                        """, show_label=False)
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "yoloe-v8s",
                        "yoloe-v8m",
                        "yoloe-v8l",
                        "yoloe-11s",
                        "yoloe-11m",
                        "yoloe-11l",
                    ],
                    value="yoloe-v8l",
                )
                image_size = gr.Slider(
                    label="Image Size",
                    minimum=320,
                    maximum=1280,
                    step=32,
                    value=640,
                )
                conf_thresh = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.25,
                )
                iou_thresh = gr.Slider(
                    label="IoU Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.70,
                )
            with gr.Column():
                output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)

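        # Tab switches toggle which input widget is visible and record the
        # active prompt mode in the hidden prompt_type textbox.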
        def update_text_image_visibility():
            return gr.update(value="Text"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

        def update_visual_image_visibility(visual_prompt_type, visual_usage_type):
            if visual_prompt_type == "bboxes":
                return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=(visual_usage_type == "Cross-Image"))
            elif visual_prompt_type == "masks":
                return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=(visual_usage_type == "Cross-Image"))

        def update_pf_image_visibility():
            return gr.update(value="Prompt-free"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

        text_tab.select(
            fn=update_text_image_visibility,
            inputs=None,
            outputs=[prompt_type, raw_image, box_image, mask_image, target_image]
        )
        visual_tab.select(
            fn=update_visual_image_visibility,
            inputs=[visual_prompt_type, visual_usage_type],
            outputs=[prompt_type, raw_image, box_image, mask_image, target_image]
        )
        prompt_free_tab.select(
            fn=update_pf_image_visibility,
            inputs=None,
            outputs=[prompt_type, raw_image, box_image, mask_image, target_image]
        )

        def update_visual_prompt_type(visual_prompt_type):
            if visual_prompt_type == "bboxes":
                return gr.update(visible=True), gr.update(visible=False)
            if visual_prompt_type == "masks":
                return gr.update(visible=False), gr.update(visible=True)
            return gr.update(visible=False), gr.update(visible=False)

        def update_visual_usage_type(visual_usage_type):
            if visual_usage_type == "Intra-Image":
                return gr.update(visible=False)
            if visual_usage_type == "Cross-Image":
                return gr.update(visible=True)
            return gr.update(visible=False)

        visual_prompt_type.change(
            fn=update_visual_prompt_type,
            inputs=[visual_prompt_type],
            outputs=[box_image, mask_image]
        )
        visual_usage_type.change(
            fn=update_visual_usage_type,
            inputs=[visual_usage_type],
            outputs=[target_image]
        )

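        # Assemble the prompt payload for the selected mode, then dispatch to YOLOE.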
        def run_inference(raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type, visual_usage_type):
            # add text/built-in prompts
            if prompt_type in ("Text", "Prompt-free"):
                target_image = None
                image = raw_image
                if prompt_type == "Prompt-free":
                    with open('tools/ram_tag_list.txt', 'r') as f:
                        texts = [x.strip() for x in f.readlines()]
                else:
                    texts = [text.strip() for text in texts.split(',')]
                prompts = {
                    "texts": texts
                }
            # add visual prompt
            elif prompt_type == "Visual":
                if visual_usage_type != "Cross-Image":
                    target_image = None
                if visual_prompt_type == "bboxes":
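                    # Each ImagePrompter point is [x1, y1, type, x2, y2, type2];
                    # entries with type == 2 are drawn boxes, so indices
                    # [0, 1, 3, 4] recover the xyxy corners.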
                    image, points = box_image["image"], box_image["points"]
                    points = np.array(points)
                    if len(points) == 0:
                        gr.Warning("No boxes are provided. No image output.", visible=True)
                        return gr.update(value=None)
                    bboxes = np.array([p[[0, 1, 3, 4]] for p in points if p[2] == 2])
                    prompts = {
                        "bboxes": bboxes,
                        "cls": np.array([0] * len(bboxes))
                    }
                elif visual_prompt_type == "masks":
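                    # Binarize the painted layer and fill interior holes so the
                    # mask prompt forms a solid region.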
                    image, masks = mask_image["background"], mask_image["layers"][0]
                    masks = np.array(masks.convert("L"))
                    masks = binary_fill_holes(masks).astype(np.uint8)
                    masks[masks > 0] = 1
                    if masks.sum() == 0:
                        gr.Warning("No masks are provided. No image output.", visible=True)
                        return gr.update(value=None)
                    prompts = {
                        "masks": masks[None],
                        "cls": np.array([0])
                    }
            return yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type)

        yoloe_infer.click(
            fn=run_inference,
            inputs=[raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type, visual_usage_type],
            outputs=[output_image],
        )

        ###################### Examples ##########################
        text_examples = gr.Examples(
            examples=[[
                "ultralytics/assets/bus.jpg",
                "person,bus",
                "yoloe-v8l",
                640,
                0.25,
                0.7]],
            inputs=[raw_image, texts, model_id, image_size, conf_thresh, iou_thresh],
            visible=True, cache_examples=False, label="Text Prompt Examples")
        box_examples = gr.Examples(
            examples=[[
                {"image": "ultralytics/assets/bus_box.jpg", "points": [[235, 408, 2, 342, 863, 3]]},
                "ultralytics/assets/zidane.jpg",
                "yoloe-v8l",
                640,
                0.2,
                0.7,
            ]],
            inputs=[box_image, target_image, model_id, image_size, conf_thresh, iou_thresh],
            visible=False, cache_examples=False, label="Box Visual Prompt Examples")
        mask_examples = gr.Examples(
            examples=[[
                {"background": "ultralytics/assets/bus.jpg", "layers": ["ultralytics/assets/bus_mask.png"], "composite": "ultralytics/assets/bus_composite.jpg"},
                "ultralytics/assets/zidane.jpg",
                "yoloe-v8l",
                640,
                0.15,
                0.7,
            ]],
            inputs=[mask_image, target_image, model_id, image_size, conf_thresh, iou_thresh],
            visible=False, cache_examples=False, label="Mask Visual Prompt Examples")
        pf_examples = gr.Examples(
            examples=[[
                "ultralytics/assets/bus.jpg",
                "yoloe-v8l",
                640,
                0.25,
                0.7,
            ]],
            inputs=[raw_image, model_id, image_size, conf_thresh, iou_thresh],
            visible=False, cache_examples=False, label="Prompt-free Examples")

        # Components update
        def load_box_example(visual_usage_type):
            return (gr.update(visible=True, value={"image": "ultralytics/assets/bus_box.jpg", "points": [[235, 408, 2, 342, 863, 3]]}),
                    gr.update(visible=(visual_usage_type == "Cross-Image")))

        def load_mask_example(visual_usage_type):
            return gr.update(visible=True), gr.update(visible=(visual_usage_type == "Cross-Image"))

        box_examples.load_input_event.then(
            fn=load_box_example,
            inputs=visual_usage_type,
            outputs=[box_image, target_image]
        )
        mask_examples.load_input_event.then(
            fn=load_mask_example,
            inputs=visual_usage_type,
            outputs=[mask_image, target_image]
        )

        # Examples update
        def update_text_examples():
            return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=False)

        def update_pf_examples():
            return gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=True)

        def update_visual_examples(visual_prompt_type):
            if visual_prompt_type == "bboxes":
                return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.Dataset(visible=False), gr.Dataset(visible=False)
            elif visual_prompt_type == "masks":
                return gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=True), gr.Dataset(visible=False)

        text_tab.select(
            fn=update_text_examples,
            inputs=None,
            outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
        )
        visual_tab.select(
            fn=update_visual_examples,
            inputs=[visual_prompt_type],
            outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
        )
        prompt_free_tab.select(
            fn=update_pf_examples,
            inputs=None,
            outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
        )
        visual_prompt_type.change(
            fn=update_visual_examples,
            inputs=[visual_prompt_type],
            outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
        )
        visual_usage_type.change(
            fn=update_visual_examples,
            inputs=[visual_prompt_type],
            outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
        )

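# Top-level page: header, links, usage notes, and the demo layout.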
gradio_app = gr.Blocks()
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
            <img src="/file=figures/logo.png" width="2.5%" style="display:inline;padding-bottom:4px">
            YOLOE: Real-Time Seeing Anything
        </h1>
        """)
    gr.HTML(
        """
        <h3 style='text-align: center'>
            <a href='https://arxiv.org/abs/2503.07465' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yoloe' target='_blank'>github</a>
        </h3>
        """)
    gr.Markdown(
        """
        We introduce **YOLOE(ye)**, a highly **efficient**, **unified**, and **open** object detection and segmentation model that, like the human eye, can see anything under different prompt mechanisms: *texts*, *visual inputs*, and a *prompt-free paradigm*.
        """
    )
    gr.Markdown(
        """
        If desired objects are not identified, please set a **smaller** confidence threshold, e.g., for visual prompts with handcrafted shapes or for cross-image prompts.
        """
    )
    gr.Markdown(
        """
        Drawing **multiple** boxes or handcrafted shapes as visual prompts in an image is also supported, and yields more accurate prompts.
        """
    )
    with gr.Row():
        with gr.Column():
            app()

if __name__ == '__main__':
    gradio_app.launch(allowed_paths=["figures"])