import os

from ultralytics import YOLO


class EndpointHandler:
    """
    Inference-endpoint handler wrapping an Ultralytics YOLO model.

    Calling an instance runs ``preprocess`` -> ``predict`` -> ``postprocess``
    on a JSON-style input dict and returns a JSON-compatible dict of detections.
    """

    def __init__(self, model_dir, weights_name="tr_roi_finetune_130_large.pt"):
        """
        Initialize the EndpointHandler with the YOLO model from the given directory.

        :param model_dir: Directory containing the model weights file
        :param weights_name: File name of the weights inside ``model_dir``;
            defaults to the file name this handler originally hard-coded
        """
        # os.path.join rather than f"{model_dir}/...": with an empty model_dir
        # the f-string produced "/<weights>", i.e. a path at the filesystem root,
        # and a trailing slash in model_dir produced a doubled separator.
        self.model = YOLO(os.path.join(model_dir, weights_name))

    def preprocess(self, inputs):
        """
        Preprocess the input data for YOLO inference.

        :param inputs: JSON input data; expected to be a dict with an ``"image"`` key
        :return: Preprocessed data (the value stored under ``"image"``, an image path)
        :raises ValueError: If ``inputs`` is not dict-like, or ``"image"`` is
            missing or empty
        """
        try:
            image_path = inputs.get("image")
        except AttributeError:
            # Non-object payloads (None, a list, a bare string) have no .get();
            # treat them as "no image" so callers see the same ValueError as for
            # a missing key instead of an AttributeError.
            image_path = None

        if not image_path:
            raise ValueError("Input JSON must contain an 'image' key with a valid path.")

        return image_path

    def predict(self, inputs):
        """
        Run inference using the YOLO model.

        :param inputs: Preprocessed input data (as returned by ``preprocess``)
        :return: Raw YOLO results (an iterable of per-image results)
        """
        return self.model(inputs)

    def postprocess(self, outputs):
        """
        Postprocess YOLO results into a JSON-compatible format.

        :param outputs: Raw YOLO results, as returned by ``predict``
        :return: ``{"detections": [...]}``, where each detection is a dict with
            ``"class"`` (label name), ``"confidence"`` and ``"box"`` entries
        """
        # Class-index -> label mapping; invariant across all boxes.
        names = self.model.names
        detections = []

        for result in outputs:
            # Results without box output (e.g. from a classification model) carry
            # boxes=None; iterating it would raise TypeError, so skip them.
            if result.boxes is None:
                continue

            for box in result.boxes:
                detections.append({
                    "class": names[int(box.cls)],
                    # NOTE(review): per-box tensors are presumably shape (1,) and
                    # (1, 4), so these serialize as [conf] and [[x1, y1, x2, y2]].
                    # Kept as-is so the response shape does not change for clients.
                    "confidence": box.conf.tolist(),
                    "box": box.xyxy.tolist(),
                })

        return {"detections": detections}

    def __call__(self, inputs):
        """
        Complete handler pipeline: preprocess -> predict -> postprocess.

        :param inputs: JSON input data
        :return: JSON output (see ``postprocess``)
        :raises ValueError: If the input has no usable ``"image"`` entry
        """
        preprocessed_data = self.preprocess(inputs)
        predictions = self.predict(preprocessed_data)
        return self.postprocess(predictions)