|
from time import perf_counter
from typing import Any, List, Optional

from ultralytics import YOLO

from inference.core.cache import cache
from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest
from inference.core.entities.responses.inference import (
    InferenceResponseImage,
    ObjectDetectionInferenceResponse,
    ObjectDetectionPrediction,
)
from inference.core.models.defaults import DEFAULT_CONFIDENCE
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.hash import get_string_list_hash
from inference.core.utils.image_utils import load_image_rgb
|
|
|
|
|
class YOLOWorld(RoboflowCoreModel):
    """YOLO-World wrapper for zero-shot (open-vocabulary) object detection.

    Attributes:
        model: The ultralytics ``YOLO`` object built from the cached
            ``yolo-world.pt`` weights file.
        class_names: Class names currently configured on the model, or
            ``None`` until ``set_classes`` has been called.
    """

    def __init__(self, *args, model_id="yolo_world/l", **kwargs):
        """Initializes the YOLO-World model.

        Args:
            *args: Positional arguments forwarded to ``RoboflowCoreModel``.
            model_id: Identifier of the model variant to load.
            **kwargs: Keyword arguments forwarded to ``RoboflowCoreModel``.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

        self.model = YOLO(self.cache_file("yolo-world.pt"))
        # No vocabulary is configured until set_classes() runs.
        self.class_names = None

    def preproc_image(self, image: Any):
        """Loads an image and reverses its channel order.

        Args:
            image: The image to preprocess, in any form accepted by
                ``load_image_rgb``.

        Returns:
            The loaded array with its last axis reversed.
        """
        np_image = load_image_rgb(image)
        # Reverse the last axis; presumably RGB -> BGR for ultralytics
        # (inferred from the helper's name -- confirm against load_image_rgb).
        return np_image[:, :, ::-1]

    def infer_from_request(
        self,
        request: YOLOWorldInferenceRequest,
    ) -> ObjectDetectionInferenceResponse:
        """Runs inference using the fields of a request object.

        Args:
            request: The request whose fields are unpacked as keyword
                arguments to ``infer``.

        Returns:
            The response produced by ``infer``.
        """
        return self.infer(**request.dict())

    def infer(
        self,
        image: Any = None,
        text: Optional[List[str]] = None,
        confidence: float = DEFAULT_CONFIDENCE,
        **kwargs,
    ) -> ObjectDetectionInferenceResponse:
        """Runs zero-shot detection on a single image.

        Args:
            image: The image to run inference on, in any form accepted by
                ``preproc_image``.
            text: Class names to detect. When given and different from the
                currently configured names, the model is reconfigured first.
            confidence: Confidence threshold passed to the model as ``conf``.
            **kwargs: Accepted and ignored, so that a full request dict can be
                unpacked into this method.

        Returns:
            The predictions, the image dimensions and the elapsed time.

        Raises:
            ValueError: If no class names are configured and ``text`` is not
                provided.
        """
        t1 = perf_counter()
        image = self.preproc_image(image)
        img_dims = image.shape

        if text is not None and text != self.class_names:
            self.set_classes(text)
        if self.class_names is None:
            raise ValueError(
                "Class names not set and not provided in the request. Must set class names before inference or provide them via the argument `text`."
            )

        # A single image is passed in, so only the first result is used.
        results = self.model.predict(
            image,
            conf=confidence,
            verbose=False,
        )[0]

        # Elapsed time covers preprocessing, class setup and prediction, but
        # not the construction of the response objects below.
        t2 = perf_counter() - t1

        predictions = []
        for box in results.boxes:
            x, y, w, h = box.xywh.tolist()[0]
            class_id = int(box.cls)
            predictions.append(
                ObjectDetectionPrediction(
                    # "class" is a Python keyword, hence the dict unpacking.
                    **{
                        "x": x,
                        "y": y,
                        "width": w,
                        "height": h,
                        "confidence": float(box.conf),
                        "class": self.class_names[class_id],
                        "class_id": class_id,
                    }
                )
            )

        return ObjectDetectionInferenceResponse(
            predictions=predictions,
            image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
            time=t2,
        )

    def set_classes(self, text: list):
        """Sets the class names (detection vocabulary) for the model.

        Text embeddings are looked up in the shared cache by a hash of the
        class list and are only recomputed by the model on a cache miss.

        Args:
            text (list): The class names.
        """
        # Work on a private copy: if the caller's list object were stored,
        # a later in-place mutation of that list would make the
        # `text != self.class_names` check in infer() compare an object with
        # itself and silently skip reconfiguring the model.
        names = list(text)

        text_hash = get_string_list_hash(names)
        cached_embeddings = cache.get_numpy(text_hash)
        if cached_embeddings is not None:
            # NOTE(review): the cached value is assigned as-is; confirm it has
            # the type and shape the model expects for txt_feats.
            self.model.model.txt_feats = cached_embeddings
            self.model.model.model[-1].nc = len(names)
        else:
            self.model.set_classes(names)
            cache.set_numpy(text_hash, self.model.model.txt_feats, expire=300)
        self.class_names = names

    def get_infer_bucket_file_list(self) -> list:
        """Gets the list of files required for inference.

        Returns:
            list: The required file names, here ``["yolo-world.pt"]``.
        """
        return ["yolo-world.pt"]
|
|