Update handler.py
Browse files- handler.py +48 -12
handler.py
CHANGED
@@ -1,20 +1,56 @@
|
|
1 |
from ultralytics import YOLO
|
2 |
-
|
3 |
|
4 |
-
class
|
5 |
-
def __init__(self,
|
6 |
-
|
7 |
-
|
|
|
|
|
8 |
|
9 |
def preprocess(self, inputs):
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
def
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
16 |
return results
|
17 |
|
18 |
def postprocess(self, outputs):
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from ultralytics import YOLO
|
2 |
+
import json
|
3 |
|
4 |
+
class YOLOHandler:
|
5 |
+
def __init__(self, model_path):
|
6 |
+
"""
|
7 |
+
Initialize the YOLO handler with the given model path.
|
8 |
+
"""
|
9 |
+
self.model = YOLO(model_path)
|
10 |
|
11 |
def preprocess(self, inputs):
|
12 |
+
"""
|
13 |
+
Preprocess the input data for YOLO inference.
|
14 |
+
:param inputs: JSON input data
|
15 |
+
:return: Preprocessed data
|
16 |
+
"""
|
17 |
+
# Assuming inputs is a dictionary with a key 'image' containing the image path or base64-encoded image
|
18 |
+
image_path = inputs.get("image")
|
19 |
+
if not image_path:
|
20 |
+
raise ValueError("Input must contain an 'image' key with a valid path.")
|
21 |
+
return image_path
|
22 |
|
23 |
+
def predict(self, inputs):
|
24 |
+
"""
|
25 |
+
Run inference on the YOLO model.
|
26 |
+
:param inputs: Preprocessed input data
|
27 |
+
:return: YOLO model results
|
28 |
+
"""
|
29 |
+
results = self.model(inputs) # Run the YOLO model
|
30 |
return results
|
31 |
|
32 |
def postprocess(self, outputs):
|
33 |
+
"""
|
34 |
+
Postprocess YOLO model results for returning JSON response.
|
35 |
+
:param outputs: Raw model output
|
36 |
+
:return: Processed results as a dictionary
|
37 |
+
"""
|
38 |
+
detections = []
|
39 |
+
for result in outputs:
|
40 |
+
for box in result.boxes:
|
41 |
+
detections.append({
|
42 |
+
"class": self.model.names[int(box.cls)], # Get class name
|
43 |
+
"confidence": box.conf.tolist(), # Confidence score
|
44 |
+
"box": box.xyxy.tolist() # Bounding box coordinates
|
45 |
+
})
|
46 |
+
return {"detections": detections}
|
47 |
+
|
48 |
+
def __call__(self, inputs):
|
49 |
+
"""
|
50 |
+
Full pipeline: preprocess, predict, postprocess.
|
51 |
+
:param inputs: JSON input data
|
52 |
+
:return: JSON output results
|
53 |
+
"""
|
54 |
+
preprocessed_data = self.preprocess(inputs)
|
55 |
+
predictions = self.predict(preprocessed_data)
|
56 |
+
return self.postprocess(predictions)
|