File size: 1,165 Bytes
69ef5c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from ultralytics import YOLO

class PearDetectionModel:
    """Wrap an Ultralytics YOLO detector and flag images that contain a burn box.

    ``config`` must provide:
      * ``"model_path"`` -- weights passed to ``YOLO(..., task="detect")``;
      * ``"classes"`` -- indexable by integer class id, yielding the label string.
    """

    # Label whose presence in the kept detections makes ``inference`` return 1.
    BURN_LABEL = "burn_bbox"
    # Confidence cut-off used when no BURN_LABEL box is among the raw detections.
    DEFAULT_CONF_THRESHOLD = 0.8
    # Looser cut-off used once at least one BURN_LABEL box is among the raw detections.
    BURN_CONF_THRESHOLD = 0.5

    def __init__(self, config) -> None:
        # NOTE(review): self.device is computed but never passed to the model or to
        # predict(); Ultralytics selects its own device. Kept for backward compatibility.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model = YOLO(config["model_path"], task="detect")

        self.names = config["classes"]

    def detect(self, img):
        """Run the detector on ``img`` and return the first result's boxes on CPU as NumPy."""
        results = self.model.predict(img)
        return results[0].boxes.cpu().numpy()

    def _labels(self, boxes):
        """Map each box's class id to its label through ``self.names``."""
        return [self.names[int(cat)] for cat in boxes.cls]

    def inference(self, img):
        """Detect objects in ``img``, filter by confidence and report burn presence.

        The confidence cut-off is BURN_CONF_THRESHOLD if any raw detection is
        labelled BURN_LABEL, otherwise DEFAULT_CONF_THRESHOLD.

        Returns:
            ``(flag, xyxy, conf)`` where ``flag`` is 1 if any kept box is labelled
            BURN_LABEL and 0 otherwise, and ``xyxy`` / ``conf`` are the kept boxes'
            coordinates and confidences (presumably shapes (N, 4) and (N,) -- verify
            against the Ultralytics Boxes API).
        """
        pred = self.detect(img)

        # Pick the threshold from what was actually detected, not from the class
        # vocabulary: a detected burn loosens the cut-off so more boxes are reported.
        burn_detected = self.BURN_LABEL in self._labels(pred)
        threshold = (
            self.BURN_CONF_THRESHOLD if burn_detected else self.DEFAULT_CONF_THRESHOLD
        )
        pred = pred[pred.conf > threshold]

        # Flag the image only if a burn box survives the confidence filter.
        flag = 1 if self.BURN_LABEL in self._labels(pred) else 0
        return flag, pred.xyxy, pred.conf

    def _preporcess(self, img):
        """Unimplemented placeholder; currently does nothing and returns None.

        NOTE(review): the name looks like a misspelling of ``_preprocess``; kept
        unchanged in case callers reference it.
        """
        pass