import torch
from ultralytics import YOLO


class PearDetectionModel:
    """Wrap a YOLO detection model and flag images containing a "burn_bbox" box."""

    # Label whose presence (after confidence filtering) makes inference() return 1.
    BURN_LABEL = "burn_bbox"
    # Confidence cut-off used when no BURN_LABEL box is among the raw detections.
    DEFAULT_CONF_THRESHOLD = 0.8
    # Looser confidence cut-off used when at least one BURN_LABEL box was detected.
    BURN_CONF_THRESHOLD = 0.5

    def __init__(self, config) -> None:
        """Load the YOLO weights and the class-name lookup.

        Args:
            config: subscriptable config providing ``"model_path"`` (passed to
                ``YOLO(..., task="detect")``) and ``"classes"`` (indexed by
                integer class id to obtain a label string).
        """
        # NOTE(review): self.device is computed but never handed to the model;
        # predict() below is called without a device argument. Kept because
        # external callers may read this attribute.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model = YOLO(config["model_path"], task="detect")
        self.names = config["classes"]

    def detect(self, img):
        """Run the model on ``img`` and return the first result's boxes as NumPy.

        Returns:
            The ``boxes`` of ``results[0]``, moved to CPU and converted with
            ``.numpy()`` (exposes ``.cls``, ``.conf`` and ``.xyxy``).
        """
        results = self.model.predict(img)
        return results[0].boxes.cpu().numpy()

    def inference(self, img, *, conf_threshold=None, burn_conf_threshold=None):
        """Detect boxes, filter them by confidence, and flag burn detections.

        The confidence threshold depends on what the model actually found.
        If any raw detection is labelled ``BURN_LABEL``, the looser
        ``burn_conf_threshold`` is applied to all boxes. Otherwise the stricter
        ``conf_threshold`` is applied.

        Args:
            img: image input forwarded unchanged to ``YOLO.predict``.
            conf_threshold: cut-off when no burn box is detected
                (defaults to ``DEFAULT_CONF_THRESHOLD``).
            burn_conf_threshold: cut-off when a burn box is detected
                (defaults to ``BURN_CONF_THRESHOLD``).

        Returns:
            Tuple ``(flag, xyxy, conf)``. ``flag`` is 1 if a ``BURN_LABEL`` box
            survives filtering, else 0. ``xyxy`` and ``conf`` belong to the
            boxes that were kept.
        """
        if conf_threshold is None:
            conf_threshold = self.DEFAULT_CONF_THRESHOLD
        if burn_conf_threshold is None:
            burn_conf_threshold = self.BURN_CONF_THRESHOLD

        pred = self.detect(img)

        # Choose the threshold from the detections themselves. Inspecting the
        # configured class list (which never changes) would make it constant.
        raw_labels = [self.names[int(cat)] for cat in pred.cls]
        threshold = (
            burn_conf_threshold if self.BURN_LABEL in raw_labels else conf_threshold
        )
        pred = pred[pred.conf > threshold]

        labels = [self.names[int(cat)] for cat in pred.cls]
        flag = 1 if self.BURN_LABEL in labels else 0
        return flag, pred.xyxy, pred.conf

    def _preporcess(self, img):
        """Placeholder for image preprocessing; currently a no-op returning None.

        The misspelled name is kept so that any existing callers keep working.
        """
        pass