File size: 1,907 Bytes
cc048b4
6278f13
f9eb4e5
 
929b40e
f9eb4e5
929b40e
4f419ef
 
6278f13
cc048b4
929b40e
 
 
f9eb4e5
929b40e
 
 
f9eb4e5
929b40e
6278f13
929b40e
 
f9eb4e5
929b40e
f9eb4e5
929b40e
f9eb4e5
6278f13
 
929b40e
f9eb4e5
 
 
929b40e
 
 
 
 
f9eb4e5
929b40e
 
 
 
 
 
 
f9eb4e5
929b40e
f9eb4e5
929b40e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os

from ultralytics import YOLO

class EndpointHandler:
    """
    Inference handler that wraps a YOLO model in a
    preprocess -> predict -> postprocess pipeline and returns the detections
    as a JSON-compatible dict.
    """

    # Weights file loaded from ``model_dir`` when no override is supplied.
    DEFAULT_MODEL_FILE = "tr_roi_finetune_130_large.pt"

    def __init__(self, model_dir, *, model_file=DEFAULT_MODEL_FILE):
        """
        Load the YOLO weights from ``model_dir``.

        :param model_dir: Directory containing the weights file (str or PathLike).
        :param model_file: Name of the weights file inside ``model_dir``; defaults
            to the fine-tuned checkpoint this handler was written for.
        """
        # os.path.join instead of a hand-built "dir/file" string: an empty
        # model_dir then means "current directory" rather than the filesystem
        # root, and a trailing separator on model_dir is handled cleanly.
        self.model = YOLO(os.path.join(model_dir, model_file))

    def preprocess(self, inputs):
        """
        Extract the image reference from the request payload.

        :param inputs: Request payload; a mapping expected to hold an ``"image"``
            key. Its value is handed unchanged to the model (presumably a file
            path or URL — confirm against callers).
        :return: The value stored under ``"image"``.
        :raises ValueError: If ``"image"`` is missing or falsy (e.g. ``""``).
        """
        image_path = inputs.get("image")
        if not image_path:
            raise ValueError("Input JSON must contain an 'image' key with a valid path.")
        return image_path

    def predict(self, inputs):
        """
        Run the YOLO model on the preprocessed input.

        :param inputs: Value returned by :meth:`preprocess`.
        :return: The model's raw results, an iterable of per-image result objects.
        """
        return self.model(inputs)

    def postprocess(self, outputs):
        """
        Flatten raw YOLO results into a JSON-compatible dict.

        :param outputs: Iterable of result objects as returned by :meth:`predict`.
        :return: ``{"detections": [...]}`` with one entry per detected box. Each
            entry holds ``"class"`` (the class name looked up in
            ``self.model.names``), ``"confidence"`` and ``"box"`` (xyxy corner
            coordinates).
        """
        detections = []
        for result in outputs:
            boxes = result.boxes
            # Ultralytics leaves ``boxes`` as None when a result carries no
            # boxes (e.g. non-detection models); skip instead of crashing.
            if boxes is None:
                continue
            for box in boxes:
                detections.append({
                    "class": self.model.names[int(box.cls)],
                    # NOTE(review): .tolist() preserves tensor shape, so these are
                    # typically a one-element list ([conf]) and a nested list
                    # ([[x1, y1, x2, y2]]) rather than scalars — confirm. Kept
                    # as-is so the response format does not change for clients.
                    "confidence": box.conf.tolist(),
                    "box": box.xyxy.tolist(),
                })
        return {"detections": detections}

    def __call__(self, inputs):
        """
        Run the full pipeline: preprocess -> predict -> postprocess.

        :param inputs: Request payload (see :meth:`preprocess`).
        :return: JSON-compatible detections dict (see :meth:`postprocess`).
        :raises ValueError: If the payload has no usable ``"image"`` value.
        """
        preprocessed_data = self.preprocess(inputs)
        predictions = self.predict(preprocessed_data)
        return self.postprocess(predictions)