import gradio as gr
import numpy
from PIL import Image

import sahi.predict
import sahi.slicing
import sahi.utils.cv
import sahi.utils.file
from sahi import AutoDetectionModel

# Resolution the detector resizes inputs to before inference.
IMAGE_SIZE = 640

# Fetch the sample image used by the Gradio examples below.
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
    "apple_tree.jpg",
)

# Load a YOLOv5s6 checkpoint through SAHI's unified model wrapper.
model = AutoDetectionModel.from_pretrained(
    model_type="yolov5",
    model_path="yolov5s6.pt",
    device="cpu",
    confidence_threshold=0.5,
    image_size=IMAGE_SIZE,
)
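
# Note: from_pretrained abstracts over the detection framework, so the same call
# pattern should also work for other SAHI-supported backends (e.g.
# model_type="mmdet" or model_type="detectron2"), assuming the corresponding
# framework package is installed.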


def sahi_yolo_inference(
    image,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    # Compute the slice grid up front so oversized requests can be rejected early.
    image_width, image_height = image.size
    sliced_bboxes = sahi.slicing.get_slice_bboxes(
        image_height,
        image_width,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        auto_slice_resolution=False,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
    )
    if len(sliced_bboxes) > 60:
        raise ValueError(
            f"{len(sliced_bboxes)} slices are too many for Hugging Face Spaces; try a larger slice size."
        )
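
    # Rough sizing: with overlap ratio r, slices advance by slice_size * (1 - r),
    # so e.g. a 2560x1440 image with 512 px slices at r=0.2 (stride ~410 px)
    # yields roughly 6 x 4 = 24 slices, comfortably under the 60-slice cap.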

    # Baseline: run the detector once on the full image.
    prediction_result_1 = sahi.predict.get_prediction(
        image=image, detection_model=model
    )
    visual_result_1 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_1.object_prediction_list,
    )
    output_1 = Image.fromarray(visual_result_1["image"])

    # SAHI: run the detector on each slice, then merge the per-slice detections.
    prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=model,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
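
    # The postprocess_* arguments above control how detections from overlapping
    # slices are merged: "NMS" suppresses duplicate boxes while "GREEDYNMM"
    # greedily merges matching ones, with candidates compared by IOU
    # (intersection over union) or IOS (intersection over smaller area).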

    visual_result_2 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_2.object_prediction_list,
    )
    output_2 = Image.fromarray(visual_result_2["image"])

    return output_1, output_2
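

# Quick local smoke test (a minimal sketch; assumes apple_tree.jpg was
# downloaded above):
#   standard, sliced = sahi_yolo_inference(Image.open("apple_tree.jpg"))
#   standard.save("standard.jpg"); sliced.save("sliced.jpg")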

inputs = [
    gr.Image(type="pil", label="Original Image"),
    gr.Number(value=512, label="slice_height"),
    gr.Number(value=512, label="slice_width"),
    gr.Number(value=0.2, label="overlap_height_ratio"),
    gr.Number(value=0.2, label="overlap_width_ratio"),
    gr.Dropdown(
        ["NMS", "GREEDYNMM"],
        type="value",
        value="NMS",
        label="postprocess_type",
    ),
    gr.Dropdown(
        ["IOU", "IOS"], type="value", value="IOU", label="postprocess_match_metric"
    ),
    gr.Number(value=0.5, label="postprocess_match_threshold"),
    gr.Checkbox(value=True, label="postprocess_class_agnostic"),
]
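
# gr.Interface passes these components to sahi_yolo_inference positionally, so
# the list order must match the function's parameter order above.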

outputs = [
    gr.Image(type="pil", label="YOLOv5s"),
    gr.Image(type="pil", label="YOLOv5s + SAHI"),
]

title = "Small Object Detection with SAHI + YOLOv5"
description = "SAHI + YOLOv5 demo for small object detection. Upload an image or click an example to try it."
article = "<p style='text-align: center'><a href='http://claireye.com.tw'>Claireye</a> | 2023</p>"
examples = [
    ["apple_tree.jpg", 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
]

gr.Interface(
    sahi_yolo_inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
    cache_examples=True,
).queue().launch(debug=True)