# MiVOLO / app.py
# admin
# sync ms
# 3d3fab6
# raw
# history blame
# 4.58 kB
import imghdr
import os
import shutil
import threading
import warnings
from dataclasses import dataclass

import cv2
import numpy as np
import gradio as gr

from mivolo.predictor import Predictor
from utils import is_url, download_file, get_jpg_files, _L, MODEL_DIR, TMP_DIR
@dataclass
class Cfg:
    """Configuration handed to MiVOLO's ``Predictor`` (see ``ValidImgDetector``).

    Field order is part of the interface: the object is built positionally as
    ``Cfg(detector_path, age_gender_path)``, leaving every other field at its
    default.
    """

    # Path to the person/face detector weights file.
    detector_weights: str
    # Path to the age/gender model checkpoint.
    checkpoint: str
    # Device the models run on; presumably a torch device string — TODO confirm.
    device: str = "cpu"
    # Presumably enables person (body) crops in the predictor — confirm in mivolo.
    with_persons: bool = True
    # Presumably disables face crops in the predictor — confirm in mivolo.
    disable_faces: bool = False
    # Presumably makes the predictor render detections into the image it
    # returns from ``recognize`` — confirm in mivolo.
    draw: bool = True
class ValidImgDetector:
    """Run MiVOLO person/face detection plus age/gender estimation on an image.

    The models are loaded once in ``__init__`` and reused by every
    ``valid_img`` call made on the same instance.
    """

    # Detection-mode label -> (use_persons, disable_faces), written into the
    # age/gender model's meta before each recognition.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    predictor = None

    def __init__(self):
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> tuple[np.ndarray, str, str, str]:
        """Detect people in *image* and summarize their estimated ages/genders.

        Args:
            image: Image as returned by ``cv2.imread`` (assumed BGR — TODO
                confirm for other callers).
            score_threshold: Detector confidence threshold (``conf`` kwarg).
            iou_threshold: Detector IoU threshold (``iou`` kwarg).
            mode: One of the keys of ``_MODES``.
            predictor: MiVOLO predictor to configure and run.

        Returns:
            ``(annotated_image, has_child, has_female, has_male)``. The image
            has its channel axis reversed; each flag is ``_L("是")`` (yes) or
            ``_L("否")`` (no), including when nothing was detected.

        Raises:
            ValueError: If *mode* is not a known detection mode.
        """
        # Validate the mode before touching the predictor so a bad call
        # leaves its configuration unchanged.
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(f"Unknown detection mode: {mode!r}") from None

        # These writes mutate the predictor in place and persist across calls.
        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        yes, no = _L("是"), _L("否")
        # Skip entries without an age (presumably objects the age/gender
        # model did not score) so min() never compares against None.
        ages = [age for age in detected_objects.ages if age is not None]
        has_child = yes if ages and min(ages) < 18 else no
        has_female = yes if "female" in detected_objects.genders else no
        has_male = yes if "male" in detected_objects.genders else no
        # Reverse the channel axis — presumably BGR (OpenCV) to RGB for display.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(
        self,
        img_path,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
        mode: str = "Use persons and faces",
    ):
        """Load *img_path* with OpenCV and run ``_detect`` using this instance's predictor.

        The keyword parameters default to the values this method always used.

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports unreadable/unsupported files by returning
            # None instead of raising; fail clearly here rather than deep
            # inside the predictor.
            raise ValueError("请正确输入图片")
        return self._detect(image, score_threshold, iou_threshold, mode, self.predictor)
# Process-wide detector, created on first use: loading the detector and
# age/gender weights is far too slow to repeat for every request.
_DETECTOR = None
# Serializes request handling. The shared predictor is reconfigured in place
# on every call, the detector is presumably an Ultralytics YOLO model (yolov8x
# weights), which Ultralytics documents as not thread-safe, and URL downloads
# reuse a single file under TMP_DIR.
_DETECTOR_LOCK = threading.Lock()


def _get_detector():
    """Return the shared ``ValidImgDetector``, creating it on first use.

    Callers must hold ``_DETECTOR_LOCK``.
    """
    global _DETECTOR
    if _DETECTOR is None:
        _DETECTOR = ValidImgDetector()
    return _DETECTOR


def infer(photo: str):
    """Gradio handler: detect people in a local image path or an image URL.

    Args:
        photo: File path (upload tab) or image URL (online tab); may be empty
            or None when the user supplied nothing.

    Returns:
        ``(status, result, child, female, male)``. ``status`` is ``"Success"``
        or the error message; on failure the other four values are ``None``.
    """
    status = "Success"
    result = child = female = male = None
    try:
        # Reject empty input before it reaches is_url().
        if not photo:
            raise ValueError("请正确输入图片")
        with _DETECTOR_LOCK:
            if is_url(photo):
                # Start every download from a clean temp directory.
                if os.path.exists(TMP_DIR):
                    shutil.rmtree(TMP_DIR)
                photo = download_file(photo, f"{TMP_DIR}/download.jpg")
            if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
                raise ValueError("请正确输入图片")
            result, child, female, male = _get_detector().valid_img(photo)
    except Exception as e:  # UI boundary: report any failure in the status bar.
        status = f"{e}"
    return status, result, child, female, male
def _result_outputs():
    """Build a fresh list of output components for one ``infer`` tab.

    Each tab needs its own component instances (a component cannot be rendered
    in two places). The order matches ``infer``'s return tuple: status,
    detection image, has-child, has-female, has-male.
    """
    return [
        gr.Textbox(label=_L("状态栏"), show_copy_button=True),
        gr.Image(
            label=_L("检测结果"),
            type="numpy",
            show_share_button=False,
        ),
        gr.Textbox(label=_L("存在儿童")),
        gr.Textbox(label=_L("存在女性")),
        gr.Textbox(label=_L("存在男性")),
    ]


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    with gr.Blocks() as iface:
        gr.Markdown(_L("# 性别年龄检测器"))
        # Upload tab: local image file, with bundled examples.
        with gr.Tab(_L("上传模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label=_L("上传照片"), type="filepath"),
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                flagging_mode="never",
                cache_examples=False,
            )
        # Online tab: image URL, downloaded by infer().
        with gr.Tab(_L("在线模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(
                    label=_L("网络图片链接"),
                    show_copy_button=True,
                ),
                outputs=_result_outputs(),
                flagging_mode="never",
            )
    iface.launch()