# Spaces:
# Running
# Running
import os | |
import cv2 | |
import imghdr | |
import shutil | |
import warnings | |
import numpy as np | |
import gradio as gr | |
from dataclasses import dataclass | |
from mivolo.predictor import Predictor | |
from utils import is_url, download_file, get_jpg_files, _L, MODEL_DIR, TMP_DIR | |
@dataclass
class Cfg:
    """Settings bundle handed to ``Predictor``.

    The class must be a dataclass: it is instantiated positionally as
    ``Cfg(detector_weights, checkpoint)``, which only works when the
    annotated fields below generate an ``__init__``.
    """

    # Path to the person/face detector weights file.
    detector_weights: str
    # Path to the age/gender model checkpoint file.
    checkpoint: str
    # Device string; defaults to CPU.
    device: str = "cpu"
    # The three flags below are consumed by Predictor (not visible in this
    # file) -- presumably they toggle person crops, face crops and drawing of
    # the annotated output image; confirm against the mivolo package.
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True
class ValidImgDetector:
    """Wrap a ``Predictor`` and report whether an image shows a child, a female or a male."""

    # Class-level placeholder; every instance replaces it in ``__init__``.
    predictor = None

    # Mode label -> (use_persons, disable_faces) flags written to the
    # age/gender model's ``meta`` before recognition.
    _MODE_FLAGS = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        """Build the predictor from the two weight files stored under ``MODEL_DIR``."""
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> tuple:
        """Run recognition on one image and summarise the detected people.

        Args:
            image: Image array, as returned by ``cv2.imread``.
            score_threshold: Value stored under the detector's ``conf`` keyword.
            iou_threshold: Value stored under the detector's ``iou`` keyword.
            mode: One of the keys of ``_MODE_FLAGS``.
            predictor: Predictor that is configured and then run.

        Returns:
            A 4-tuple ``(annotated_image, has_child, has_female, has_male)``.
            The image is the predictor's output with its last axis reversed;
            the three flags are always the localized yes/no labels produced
            by ``_L``, including when nothing was detected.

        Raises:
            ValueError: If ``mode`` is not a known mode label.
        """
        # Validate the mode before touching the predictor so that a bad value
        # neither leaves locals unbound nor half-configures shared state.
        try:
            use_persons, disable_faces = self._MODE_FLAGS[mode]
        except KeyError:
            raise ValueError(f"Unknown detection mode: {mode!r}") from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        def label(flag: bool):
            """Map a boolean onto the localized yes/no label."""
            return _L("是") if flag else _L("否")

        # NOTE(review): entries without a prediction are presumably None in
        # the library's result lists -- dropping them keeps min() from
        # raising; confirm against the mivolo result type.
        ages = [age for age in detected_objects.ages if age is not None]
        genders = detected_objects.genders

        # Previously the "nothing detected" path returned bare False while the
        # other path returned labels; always return labels for consistency.
        has_child = label(bool(ages) and min(ages) < 18)
        has_female = label("female" in genders)
        has_male = label("male" in genders)

        # Reverse the channel axis: cv2 loads images as BGR, so this
        # presumably hands an RGB array to the UI.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(self, img_path):
        """Load ``img_path`` and analyse it using both person and face crops.

        Uses a confidence threshold of 0.4 and an IoU threshold of 0.7.

        Raises:
            ValueError: If the file cannot be decoded as an image.
        """
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports an unreadable or missing file by returning
            # None instead of raising; fail here with a clear message.
            raise ValueError(f"Cannot read image file: {img_path}")
        return self._detect(image, 0.4, 0.7, "Use persons and faces", self.predictor)
def infer(photo: str):
    """Analyse one photo given as a local file path or an image URL.

    Args:
        photo: Path of an image file, or a URL that is first downloaded to
            ``{TMP_DIR}/download.jpg``. Empty or ``None`` input is rejected.

    Returns:
        A 5-tuple ``(status, result, child, female, male)``. ``status`` is
        ``"Success"`` or the text of the exception that occurred; on failure
        the remaining four items are ``None``.
    """
    status = "Success"
    result = child = female = male = None
    try:
        # Only probe for a URL when there is something to probe; empty input
        # falls through to the validation below and gets the intended message.
        if photo and is_url(photo):
            # Start from a clean temp directory so stale downloads never leak
            # into the current request.
            if os.path.exists(TMP_DIR):
                shutil.rmtree(TMP_DIR)
            photo = download_file(photo, f"{TMP_DIR}/download.jpg")

        # NOTE(review): imghdr is deprecated since Python 3.11 and removed in
        # 3.13 (PEP 594); plan a replacement for the file-type sniffing.
        if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
            raise ValueError("请正确输入图片")

        # Build the detector only after the input is known to be valid, so the
        # model weights are not loaded just to reject a bad request.
        # NOTE(review): the weights are still reloaded on every call; consider
        # reusing a single detector instance if concurrent use is safe.
        detector = ValidImgDetector()
        result, child, female, male = detector.valid_img(photo)
    except Exception as e:
        # UI boundary: surface any failure in the status box instead of
        # letting it propagate.
        status = f"{e}"
    return status, result, child, female, male
if __name__ == "__main__":

    def _result_widgets():
        """Create a fresh set of output widgets: status, annotated image, three flags."""
        return [
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.Image(
                label=_L("检测结果"),
                type="numpy",
                show_share_button=False,
            ),
            gr.Textbox(label=_L("存在儿童")),
            gr.Textbox(label=_L("存在女性")),
            gr.Textbox(label=_L("存在男性")),
        ]

    # Silence all warnings for the lifetime of the app process.
    warnings.filterwarnings("ignore")

    with gr.Blocks() as iface:
        gr.Markdown(_L("# 性别年龄检测器"))

        # Tab 1: the user uploads an image; infer() receives its file path.
        with gr.Tab(_L("上传模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label=_L("上传照片"), type="filepath"),
                outputs=_result_widgets(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                flagging_mode="never",
                cache_examples=False,
            )

        # Tab 2: the user pastes an image URL; infer() receives the text.
        with gr.Tab(_L("在线模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(
                    label=_L("网络图片链接"),
                    show_copy_button=True,
                ),
                outputs=_result_widgets(),
                flagging_mode="never",
            )

    iface.launch()