# sahi-yolo11 / app.py

import sys
import types

import gradio as gr
import numpy
import sahi.predict
import sahi.slicing
import sahi.utils.cv
import sahi.utils.file
from PIL import Image
from sahi import AutoDetectionModel
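# Compatibility shim: some SAHI releases import RepositoryNotFoundError from
# huggingface_hub.utils._errors, a private module removed in newer
# huggingface_hub versions; register a stand-in so those imports succeed.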
if 'huggingface_hub.utils._errors' not in sys.modules:
    mock_errors = types.ModuleType('_errors')
    mock_errors.RepositoryNotFoundError = Exception
    sys.modules['huggingface_hub.utils._errors'] = mock_errors

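# Inference resolution (in pixels) passed to the detection model.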
IMAGE_SIZE = 640

# Download example images used by the demo.
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
    "apple_tree.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
    "highway.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
    "highway2.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
    "highway3.jpg",
)

# Load YOLO11s through SAHI's AutoDetectionModel wrapper (CPU-only Space).
model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics",
    model_path="yolo11s.pt",
    device="cpu",
    confidence_threshold=0.5,
    image_size=IMAGE_SIZE,
)

def sahi_yolo_inference(
    image,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
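    """Run the detector twice on the input image: standard full-image
    inference and SAHI sliced inference, returning both annotated images
    for side-by-side comparison."""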
    image_width, image_height = image.size

    # Compute the slice grid up front to reject settings that would produce
    # more slices than this CPU-only Space can handle.
    sliced_bboxes = sahi.slicing.get_slice_bboxes(
        image_height,
        image_width,
        int(slice_height),
        int(slice_width),
        False,  # auto_slice_resolution
        overlap_height_ratio,
        overlap_width_ratio,
    )
    if len(sliced_bboxes) > 60:
        raise ValueError(
            f"{len(sliced_bboxes)} slices are too many for Hugging Face Spaces; try a larger slice size."
        )
    # Standard inference over the full image.
    prediction_result_1 = sahi.predict.get_prediction(
        image=image, detection_model=model
    )
    visual_result_1 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_1.object_prediction_list,
    )
    output_1 = Image.fromarray(visual_result_1["image"])
    # Sliced inference: SAHI tiles the image, runs the detector on each slice,
    # and merges per-slice detections using the chosen postprocess settings.
    prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=model,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    visual_result_2 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_2.object_prediction_list,
    )
    output_2 = Image.fromarray(visual_result_2["image"])

    return output_1, output_2

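# Minimal local usage sketch (assumed; not executed by the Space, which drives
# the function through the Gradio interface below):
#
#   standard, sliced = sahi_yolo_inference(Image.open("highway.jpg"))
#   standard.save("highway_standard.png")
#   sliced.save("highway_sliced.png")
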
inputs = [
    gr.Image(type="pil", label="Original Image"),
    gr.Number(value=512, label="slice_height"),
    gr.Number(value=512, label="slice_width"),
    gr.Number(value=0.2, label="overlap_height_ratio"),
    gr.Number(value=0.2, label="overlap_width_ratio"),
    gr.Dropdown(
        ["NMS", "GREEDYNMM"],
        type="value",
        value="NMS",
        label="postprocess_type",
    ),
    gr.Dropdown(
        ["IOU", "IOS"], type="value", value="IOU", label="postprocess_match_metric"
    ),
    gr.Number(value=0.5, label="postprocess_match_threshold"),
    gr.Checkbox(value=True, label="postprocess_class_agnostic"),
]
outputs = [
    gr.Image(type="pil", label="YOLO11s Standard"),
    gr.Image(type="pil", label="YOLO11s + SAHI Sliced"),
]

title = "Small Object Detection with SAHI + YOLO11"
description = "SAHI + YOLO11 demo for small object detection. Upload your own image or click one of the example images below."
article = "<p style='text-align: center'>SAHI is a lightweight vision library for performing large-scale object detection and instance segmentation. <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a></p>"
examples = [
    ["apple_tree.jpg", 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
    ["highway.jpg", 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
    ["highway2.jpg", 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
    ["highway3.jpg", 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
]

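# Wire the inference function, inputs, outputs, and examples into a Gradio
# Interface and launch the demo.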
gr.Interface(
    sahi_yolo_inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
    cache_examples=True,
).launch(debug=True)