# spot-yolo-new / handler.py
# Author: ryanhuangtw
# Last change: "Update handler.py" (commit 4f419ef)
import os
from collections.abc import Mapping

from ultralytics import YOLO
class EndpointHandler:
    """
    Inference handler that wraps an Ultralytics YOLO model.

    Calling the handler with a JSON-like dict runs the model on the value of
    its ``"image"`` key. It returns the detections as plain Python
    lists/dicts, in the form ``{"detections": [...]}``.
    """

    # Historical hard-coded weights file name; kept as the default so callers
    # that pass only ``model_dir`` keep loading the same model.
    DEFAULT_MODEL_FILE = "tr_roi_finetune_130_large.pt"

    def __init__(self, model_dir, model_file=DEFAULT_MODEL_FILE):
        """
        Initialize the EndpointHandler with the YOLO model from the given directory.

        :param model_dir: Directory containing the model weights.
        :param model_file: Weights file name inside ``model_dir``.
            Defaults to ``DEFAULT_MODEL_FILE``.
        """
        self.model = YOLO(os.path.join(model_dir, model_file))

    def preprocess(self, inputs):
        """
        Validate the request payload and extract the image reference.

        :param inputs: Decoded JSON payload; must be a mapping with an ``"image"`` key.
        :return: The value stored under ``"image"`` (passed unchanged to the model).
        :raises ValueError: If ``inputs`` is not a mapping, or ``"image"`` is
            missing or falsy.
        """
        # Non-dict payloads (e.g. a bare string or list) used to fail with an
        # AttributeError on ``.get``; report them as a clear input error instead.
        if not isinstance(inputs, Mapping):
            raise ValueError(
                f"Input must be a JSON object with an 'image' key, got {type(inputs).__name__}."
            )
        image_path = inputs.get("image")
        if not image_path:
            raise ValueError("Input JSON must contain an 'image' key with a valid path.")
        return image_path

    def predict(self, inputs):
        """
        Run inference using the YOLO model.

        :param inputs: Image reference returned by :meth:`preprocess`.
        :return: Whatever calling the YOLO model returns (iterated by :meth:`postprocess`).
        """
        return self.model(inputs)

    def postprocess(self, outputs):
        """
        Convert YOLO results into a JSON-serializable dict.

        :param outputs: Iterable of results, each exposing a ``boxes`` attribute.
        :return: ``{"detections": [{"class", "confidence", "box"}, ...]}``.
            ``"class"`` is looked up in ``self.model.names``. ``"confidence"``
            and ``"box"`` are the ``.tolist()`` of each box's ``conf`` and
            ``xyxy``. NOTE(review): with Ultralytics ``Boxes`` iteration each
            box is presumably a single row. If so, these come out nested
            (``[score]`` and ``[[x1, y1, x2, y2]]``). The shape is kept as-is
            because clients may rely on it.
        """
        names = self.model.names  # loop-invariant class-id -> name mapping
        detections = []
        for result in outputs:
            boxes = result.boxes
            # Non-detection results can carry no boxes at all; skip them.
            if boxes is None:
                continue
            for box in boxes:
                detections.append({
                    "class": names[int(box.cls)],  # Class name
                    "confidence": box.conf.tolist(),  # Confidence score
                    "box": box.xyxy.tolist(),  # Bounding box coordinates
                })
        return {"detections": detections}

    def __call__(self, inputs):
        """
        Complete handler pipeline: preprocess -> predict -> postprocess.

        :param inputs: Decoded JSON payload containing an ``"image"`` key.
        :return: JSON-serializable detections dict.
        """
        preprocessed_data = self.preprocess(inputs)
        predictions = self.predict(preprocessed_data)
        return self.postprocess(predictions)