import os

import numpy as np
import torch
from fastai.vision.all import load_learner
from huggingface_hub import hf_hub_download

from constants import (CLAS_FILENAME, CLAS_FILEPATH, CLAS_THRESHOLD,
                       DET_FILENAME, DET_FILEPATH, DET_NAME, DET_THRESHOLD,
                       DEVICE, HF_CLAS_REPO_NAME, HF_DET_REPO_NAME,
                       MODELS_PATH)
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model


def localize_trash(im):
    """Run the detector on ``im`` and return the confident detections.

    The detector checkpoint is fetched from the Hugging Face Hub into
    ``MODELS_PATH`` when ``DET_FILEPATH`` does not exist yet.  The model is
    rebuilt on every call.

    Args:
        im: Input image.  It must be accepted by ``get_transforms`` and
            expose ``.size`` (presumably a PIL image -- confirm with caller).

    Returns:
        A ``(probas, bboxes_scaled)`` pair: ``probas`` holds columns 4 and
        onward of every kept prediction row (column 4 is the value compared
        against ``DET_THRESHOLD``), and ``bboxes_scaled`` is the result of
        ``rescale_bboxes`` applied to the first four columns of those rows.
    """
    # detector, if checkpoint doesn't exist then download from hf
    if not os.path.exists(DET_FILEPATH):
        hf_hub_download(HF_DET_REPO_NAME, DET_FILENAME, local_dir=MODELS_PATH)
    detector = set_model(DET_NAME, 1, DET_FILEPATH, DEVICE)
    detector.eval()

    # mean-std normalize the input image (batch-size: 1)
    img = get_transforms(im)

    # propagate through the model
    outputs = detector(img.to(DEVICE))

    # keep only predictions above set confidence
    bboxes_keep = outputs[0, outputs[0, :, 4] > DET_THRESHOLD]
    probas = bboxes_keep[:, 4:]

    # convert boxes to image scales; the trailing dims of the transformed
    # tensor are passed as the size the network actually saw
    bboxes_scaled = rescale_bboxes(
        bboxes_keep[:, :4], im.size, tuple(img.size()[2:])
    )
    return probas, bboxes_scaled


def classify_trash(im, probas, bboxes_scaled):
    """Classify every detected box and keep the confident ones.

    The classifier checkpoint is fetched from the Hugging Face Hub into
    ``MODELS_PATH`` when ``CLAS_FILEPATH`` does not exist yet.

    Args:
        im: Input image; must support ``.crop((xmin, ymin, xmax, ymax))``.
        probas: One row per box, as returned by ``localize_trash``.  Each
            row's entries 0 and 1 are overwritten (see the NOTE below).
        bboxes_scaled: Boxes in image coordinates; must support
            ``.tolist()`` yielding ``(xmin, ymin, xmax, ymax)`` rows.

    Returns:
        A ``(cls_prob, bboxes_final)`` pair of equally long lists.  Each
        ``cls_prob`` entry is a row with ``[0]`` set to the truncated top
        class probability in percent and ``[1]`` set to the top class index;
        ``bboxes_final`` holds the matching ``(xmin, ymin, xmax, ymax)``
        tuples.  Only boxes whose percentage reaches
        ``CLAS_THRESHOLD * 100`` are kept.
    """
    # classifier, if checkpoint doesn't exist then download from hf
    if not os.path.exists(CLAS_FILEPATH):
        hf_hub_download(HF_CLAS_REPO_NAME, CLAS_FILENAME, local_dir=MODELS_PATH)
    classifier = load_learner(CLAS_FILEPATH)

    bboxes_final = []
    cls_prob = []
    for p, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
        img = im.crop((xmin, ymin, xmax, ymax))
        outputs = classifier.predict(img)
        class_scores = outputs[2]

        # NOTE(review): if ``probas`` is a torch tensor, ``p`` is a row view,
        # so these writes also modify the caller's ``probas`` in place.
        p[1] = torch.topk(class_scores, k=1).indices.squeeze(0).item()
        # Native torch truncation: no NumPy round-trip on a tensor.
        p[0] = torch.max(torch.trunc(class_scores * 100))

        if p[0] >= CLAS_THRESHOLD * 100:
            bboxes_final.append((xmin, ymin, xmax, ymax))
            cls_prob.append(p)
    return cls_prob, bboxes_final


def detect_trash(img):
    """Localize and classify trash in ``img``.

    Args:
        img: Input image, forwarded to ``localize_trash`` and
            ``classify_trash``.

    Returns:
        The ``(cls_prob, bboxes_final)`` pair produced by
        ``classify_trash``.
    """
    # Inference only: disable autograd for the duration of this call instead
    # of switching it off for the whole process.
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(img)

        # 2) Classify
        cls_prob, bboxes_final = classify_trash(img, probas, bboxes_scaled)

    return cls_prob, bboxes_final