waste-classifier / trash_detector.py
santit96's picture
Stop versioning the model checkpoints, now they are downloaded from huggingface. Add env vars
dd14920
raw
history blame
2.28 kB
import os
import numpy as np
import torch
from fastai.vision.all import load_learner
from huggingface_hub import hf_hub_download
from constants import (CLAS_FILENAME, CLAS_FILEPATH, CLAS_THRESHOLD,
DET_FILENAME, DET_FILEPATH, DET_NAME, DET_THRESHOLD,
DEVICE, HF_CLAS_REPO_NAME, HF_DET_REPO_NAME,
MODELS_PATH)
from efficientdet.efficientdet import get_transforms, rescale_bboxes, set_model
def localize_trash(im):
    """Locate candidate trash objects in an image with the detector model.

    The detector checkpoint is downloaded from the Hugging Face Hub into
    ``MODELS_PATH`` when ``DET_FILEPATH`` does not exist yet. The model is
    rebuilt from that checkpoint on every call.

    Args:
        im: Input image. It is handed to ``get_transforms`` and its ``size``
            attribute is used to rescale the boxes (presumably a PIL image;
            confirm against the caller).

    Returns:
        A ``(probas, bboxes_scaled)`` tuple. ``probas`` holds column 4 onwards
        of every detection of the first batch item whose column-4 value
        exceeds ``DET_THRESHOLD``. ``bboxes_scaled`` holds the first four
        columns of those detections after ``rescale_bboxes``.
    """
    # Fetch the checkpoint only when it is missing locally.
    if not os.path.exists(DET_FILEPATH):
        hf_hub_download(HF_DET_REPO_NAME, DET_FILENAME, local_dir=MODELS_PATH)

    detector = set_model(DET_NAME, 1, DET_FILEPATH, DEVICE)
    detector.eval()

    # Preprocess the image into the model input (the original comment
    # describes this as mean-std normalization with batch size 1).
    img = get_transforms(im)

    # Pure inference: do not build an autograd graph, even when the caller
    # has not disabled gradients globally.
    with torch.no_grad():
        outputs = detector(img.to(DEVICE))

    # Keep only the detections of the first batch item above the threshold.
    detections = outputs[0]
    bboxes_keep = detections[detections[:, 4] > DET_THRESHOLD]
    probas = bboxes_keep[:, 4:]

    # Convert the boxes from model-input scale back to the image scale.
    bboxes_scaled = rescale_bboxes(
        bboxes_keep[:, :4], im.size, tuple(img.size()[2:])
    )
    return probas, bboxes_scaled
def classify_trash(im, probas, bboxes_scaled):
    """Classify every localized region and keep the confident predictions.

    The classifier checkpoint is downloaded from the Hugging Face Hub into
    ``MODELS_PATH`` when ``CLAS_FILEPATH`` does not exist yet. The learner is
    reloaded from that checkpoint on every call.

    Args:
        im: Source image; must provide ``crop((xmin, ymin, xmax, ymax))``.
        probas: Iterable of per-detection rows with at least two writable
            slots. Each row is overwritten IN PLACE: slot 0 receives the
            top class probability as a truncated percentage, slot 1 the
            index of the top class.
        bboxes_scaled: Boxes matching ``probas`` row by row; must provide
            ``tolist()`` yielding ``(xmin, ymin, xmax, ymax)`` entries.

    Returns:
        A ``(cls_prob, bboxes_final)`` tuple of lists restricted to the
        detections whose percentage is at least ``CLAS_THRESHOLD * 100``.
    """
    # Fetch the checkpoint only when it is missing locally.
    if not os.path.exists(CLAS_FILEPATH):
        hf_hub_download(HF_CLAS_REPO_NAME, CLAS_FILENAME, local_dir=MODELS_PATH)
    classifier = load_learner(CLAS_FILEPATH)

    # Threshold expressed on the same 0-100 scale as the stored confidence.
    min_confidence = CLAS_THRESHOLD * 100

    bboxes_final = []
    cls_prob = []
    for p, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
        crop = im.crop((xmin, ymin, xmax, ymax))
        # Third element of the prediction is the per-class probability tensor.
        class_probs = classifier.predict(crop)[2]

        p[1] = torch.topk(class_probs, k=1).indices.squeeze(0).item()
        # Stay in torch: applying NumPy ufuncs to a tensor depends on
        # array-protocol interop and breaks for non-CPU tensors.
        p[0] = torch.trunc(class_probs * 100).max()

        if p[0] >= min_confidence:
            bboxes_final.append((xmin, ymin, xmax, ymax))
            cls_prob.append(p)
    return cls_prob, bboxes_final
def detect_trash(img):
    """Run the two-stage pipeline on an image: localize, then classify.

    Args:
        img: Input image, forwarded unchanged to ``localize_trash`` and
            ``classify_trash``.

    Returns:
        The ``(cls_prob, bboxes_final)`` tuple produced by
        ``classify_trash``.
    """
    # Inference only: disable autograd for the duration of the pipeline
    # instead of switching it off for the whole process, so the caller's
    # gradient mode is restored afterwards (also when an exception is raised).
    with torch.no_grad():
        # 1) Localize
        probas, bboxes_scaled = localize_trash(img)
        # 2) Classify
        cls_prob, bboxes_final = classify_trash(img, probas, bboxes_scaled)
    return cls_prob, bboxes_final