ryanhuangtw commited on
Commit
929b40e
·
verified ·
1 Parent(s): 1280207

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +48 -12
handler.py CHANGED
@@ -1,20 +1,56 @@
1
  from ultralytics import YOLO
2
- from huggingface_inference_toolkit.handler import BaseHandler
3
 
4
- class CustomYOLOHandler(BaseHandler):
5
- def __init__(self, model_dir, *args, **kwargs):
6
- super().__init__(*args, **kwargs)
7
- self.model = YOLO(f"{model_dir}/model.pt")
 
 
8
 
9
  def preprocess(self, inputs):
10
- # Preprocess inputs for YOLO
11
- return inputs
 
 
 
 
 
 
 
 
12
 
13
- def inference(self, inputs):
14
- # Perform inference
15
- results = self.model(inputs)
 
 
 
 
16
  return results
17
 
18
  def postprocess(self, outputs):
19
- # Convert YOLO results to HuggingFace pipeline outputs
20
- return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from ultralytics import YOLO
2
+ import json
3
 
4
+ class YOLOHandler:
5
+ def __init__(self, model_path):
6
+ """
7
+ Initialize the YOLO handler with the given model path.
8
+ """
9
+ self.model = YOLO(model_path)
10
 
11
  def preprocess(self, inputs):
12
+ """
13
+ Preprocess the input data for YOLO inference.
14
+ :param inputs: JSON input data
15
+ :return: Preprocessed data
16
+ """
17
+ # Assuming inputs is a dictionary with a key 'image' containing the image path or base64-encoded image
18
+ image_path = inputs.get("image")
19
+ if not image_path:
20
+ raise ValueError("Input must contain an 'image' key with a valid path.")
21
+ return image_path
22
 
23
+ def predict(self, inputs):
24
+ """
25
+ Run inference on the YOLO model.
26
+ :param inputs: Preprocessed input data
27
+ :return: YOLO model results
28
+ """
29
+ results = self.model(inputs) # Run the YOLO model
30
  return results
31
 
32
  def postprocess(self, outputs):
33
+ """
34
+ Postprocess YOLO model results for returning JSON response.
35
+ :param outputs: Raw model output
36
+ :return: Processed results as a dictionary
37
+ """
38
+ detections = []
39
+ for result in outputs:
40
+ for box in result.boxes:
41
+ detections.append({
42
+ "class": self.model.names[int(box.cls)], # Get class name
43
+ "confidence": box.conf.tolist(), # Confidence score
44
+ "box": box.xyxy.tolist() # Bounding box coordinates
45
+ })
46
+ return {"detections": detections}
47
+
48
+ def __call__(self, inputs):
49
+ """
50
+ Full pipeline: preprocess, predict, postprocess.
51
+ :param inputs: JSON input data
52
+ :return: JSON output results
53
+ """
54
+ preprocessed_data = self.preprocess(inputs)
55
+ predictions = self.predict(preprocessed_data)
56
+ return self.postprocess(predictions)