MiVOLO / app.py
jaimin's picture
Update app.py
1011fc1 verified
raw
history blame
3.81 kB
import os
import cv2
import imghdr
import shutil
import warnings
import numpy as np
import gradio as gr
from dataclasses import dataclass
from mivolo.predictor import Predictor
from utils import is_url, download_file, get_jpg_files, MODEL_DIR
# Scratch directory that infer() downloads URL images into.
# NOTE(review): infer() removes this whole directory (shutil.rmtree) before
# every URL download, so pointing it at ./__pycache__ also discards anything
# else stored there (e.g. Python bytecode caches) — consider a dedicated dir.
TMP_DIR: str = "./__pycache__"
@dataclass
class Cfg:
    """Configuration object handed to mivolo's ``Predictor``.

    Only the two weight/checkpoint paths are required (they are passed
    positionally by ``ValidImgDetector.__init__``); every other field keeps
    the default declared below.
    """

    # File path of the person/face detector weights.
    detector_weights: str
    # File path of the age/gender model checkpoint.
    checkpoint: str
    # Device string forwarded to the predictor; inference runs on CPU by default.
    device: str = "cpu"
    # Flags forwarded unchanged to the predictor (semantics defined by mivolo).
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True
class ValidImgDetector:
    """Run a MiVOLO ``Predictor`` and report whether a child, female or male is present."""

    predictor = None

    # UI mode label -> (use_persons, disable_faces) written onto the model meta.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: "Predictor",
    ) -> "tuple[np.ndarray, bool, bool, bool]":
        """Configure ``predictor``, run it on ``image`` and summarise the results.

        Args:
            image: Image array handed straight to ``predictor.recognize``.
                ``valid_img`` supplies ``cv2.imread`` output, i.e. BGR order.
            score_threshold: Written to the detector's ``conf`` kwarg.
            iou_threshold: Written to the detector's ``iou`` kwarg.
            mode: One of the keys of ``_MODES``.
            predictor: The predictor to configure and run (mutated in place).

        Returns:
            ``(annotated_image, has_child, has_female, has_male)``, where the
            annotated image has its channel axis reversed for display.
            ``has_child`` means some predicted age is below 18.

        Raises:
            ValueError: If ``mode`` is not a known mode label.
        """
        # Validate before touching predictor state so a bad mode leaves it unchanged.
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {list(self._MODES)}"
            ) from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        has_child = has_female = has_male = False
        if len(detected_objects.ages) > 0:
            # NOTE(review): ages may hold None for objects without a prediction;
            # min() over a list containing None raises TypeError, so skip them.
            known_ages = [age for age in detected_objects.ages if age is not None]
            has_child = bool(known_ages) and min(known_ages) < 18
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders

        # Reverse channels for the RGB gr.Image output; presumably recognize()
        # draws on the BGR input it was given — confirm against mivolo.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(
        self,
        img_path,
        score_threshold=0.4,
        iou_threshold=0.7,
        mode="Use persons and faces",
    ):
        """Load ``img_path`` with OpenCV and run ``_detect`` on it.

        The defaults reproduce the original fixed settings (0.4 / 0.7,
        persons and faces).

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising,
            # e.g. for formats it cannot decode such as GIF.
            raise ValueError(f"Could not read image: {img_path}")
        return self._detect(image, score_threshold, iou_threshold, mode, self.predictor)
# Shared detector instance, built on first successful request and reused.
# Building a ValidImgDetector constructs a Predictor from the model checkpoint
# files, so doing it per request repeats that work every time.
_DETECTOR = None


def _get_detector():
    """Return the shared ValidImgDetector, creating it on first use."""
    global _DETECTOR
    if _DETECTOR is None:
        _DETECTOR = ValidImgDetector()
    return _DETECTOR


def infer(photo: str):
    """Detect people in an uploaded image path or an image URL.

    Args:
        photo: Local file path (Upload tab) or URL (Online tab). May be
            None or empty when the user submits without any input.

    Returns:
        ``(annotated_image, has_child, has_female, has_male)`` on success, or
        ``(None, None, None, "Please input the image correctly")`` when the
        input is missing, does not exist, or is not a recognised image.
        NOTE(review): in the error case the message lands in the fourth
        output, which the UI labels "Has Male".
    """
    # Reject empty input before anything else so is_url never sees None.
    if photo and is_url(photo):
        # NOTE(review): this deletes the entire TMP_DIR before each download.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
        photo = download_file(photo, f"{TMP_DIR}/download.jpg")

    if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
        return None, None, None, "Please input the image correctly"

    # Only touch the (lazily built) model once the input is known to be valid.
    return _get_detector().valid_img(photo)
if __name__ == "__main__":

    def _result_outputs():
        """Create a fresh set of result components (one set per tab)."""
        return [
            gr.Image(label="Detection Result", type="numpy"),
            gr.Textbox(label="Has Child"),
            gr.Textbox(label="Has Female"),
            gr.Textbox(label="Has Male"),
        ]

    # Two tabs share the same inference function and output layout; they
    # differ only in how the image is supplied (file upload vs. URL text).
    with gr.Blocks() as iface:
        warnings.filterwarnings("ignore")

        with gr.Tab("Upload Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label="Upload Photo", type="filepath"),
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                allow_flagging="never",
                cache_examples=False,
            )

        with gr.Tab("Online Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(label="Online Picture URL"),
                outputs=_result_outputs(),
                allow_flagging="never",
            )

    iface.launch()