File size: 4,576 Bytes
eaa17fe
8ac55ce
 
eaa17fe
8ac55ce
 
eaa17fe
8ac55ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d3fab6
 
 
8ac55ce
 
 
 
 
 
 
 
 
 
 
eaa17fe
8ac55ce
 
 
eaa17fe
8ac55ce
eaa17fe
8ac55ce
 
 
eaa17fe
8ac55ce
eaa17fe
 
8ac55ce
eaa17fe
8ac55ce
eaa17fe
 
 
8ac55ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import cv2
import imghdr
import shutil
import warnings
import numpy as np
import gradio as gr
from dataclasses import dataclass
from mivolo.predictor import Predictor
from utils import is_url, download_file, get_jpg_files, _L, MODEL_DIR, TMP_DIR


@dataclass()
class Cfg:
    """Configuration object passed to ``mivolo.predictor.Predictor``.

    Only the two weight-file paths are required; every other field has a
    default (CPU device, persons enabled, faces enabled, drawing enabled).
    """

    # Path to the person/face detector weights (e.g. ``yolov8x_person_face.pt``).
    detector_weights: str
    # Path to the age/gender model checkpoint.
    checkpoint: str
    # Device string handed to the predictor; defaults to CPU.
    device: str = "cpu"
    # Whether person detections are used in addition to faces.
    with_persons: bool = True
    # When True, face detections are not used.
    disable_faces: bool = False
    # Presumably asks the predictor to render detections onto the output image.
    draw: bool = True


class ValidImgDetector:
    """Age/gender screening of a single image using a mivolo ``Predictor``.

    ``valid_img`` reports, as localized yes/no strings, whether the image
    contains anyone predicted to be under 18, any female, and any male.
    """

    predictor = None

    # UI mode label -> (use_persons, disable_faces), written onto the
    # age/gender model's meta before each recognition.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        # Remaining Cfg fields keep their defaults (cpu, persons on, faces on, draw on).
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> tuple:
        """Run ``predictor`` on ``image`` and summarize the detections.

        Args:
            image: Image as loaded by ``cv2.imread`` (BGR channel order).
            score_threshold: Confidence threshold, set as the detector's ``conf`` kwarg.
            iou_threshold: IoU threshold, set as the detector's ``iou`` kwarg.
            mode: One of the keys of ``_MODES``.
            predictor: The mivolo predictor to configure and run; its detector
                kwargs and age/gender model meta are overwritten in place.

        Returns:
            ``(out_im_rgb, has_child, has_female, has_male)``; the last three
            are ``_L("是")`` or ``_L("否")``.

        Raises:
            ValueError: If ``mode`` is not a known mode label.
        """
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(f"Unknown detection mode: {mode!r}") from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces
        detected_objects, out_im = predictor.recognize(image)

        yes, no = _L("是"), _L("否")
        # NOTE(review): age entries may be None for objects the age/gender
        # model did not score; skip them so min() cannot hit a None.
        # Confirm against mivolo's result structure.
        ages = [age for age in detected_objects.ages if age is not None]
        genders = detected_objects.genders
        # Default to the localized "no" rather than the bool False, so an image
        # with no detections shows the same strings as any other result.
        has_child = yes if ages and min(ages) < 18 else no
        has_female = yes if "female" in genders else no
        has_male = yes if "male" in genders else no

        # Reverse the last (channel) axis; presumably BGR (OpenCV) -> RGB for display.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(self, img_path):
        """Load ``img_path`` with OpenCV and screen it with fixed settings.

        Uses score threshold 0.4, IoU threshold 0.7 and both persons and faces.

        Raises:
            ValueError: If OpenCV cannot read/decode the file.
        """
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError("请正确输入图片")
        return self._detect(image, 0.4, 0.7, "Use persons and faces", self.predictor)


# Shared detector, created on first use by ``_get_detector``. Constructing a
# ValidImgDetector builds a mivolo Predictor from its weight files, so one
# instance is reused instead of rebuilding it for every request.
_DETECTOR = None


def _get_detector():
    """Return the process-wide ValidImgDetector, creating it on first call."""
    global _DETECTOR
    if _DETECTOR is None:
        _DETECTOR = ValidImgDetector()
    return _DETECTOR


def infer(photo: str):
    """Run age/gender detection on a local image path or an image URL.

    Args:
        photo: Local file path (upload tab) or URL (online tab). A URL is
            downloaded to ``{TMP_DIR}/download.jpg`` after ``TMP_DIR`` is cleared.

    Returns:
        ``(status, result, child, female, male)``. ``status`` is ``"Success"``
        or the text of the exception raised; on failure the other four values
        are ``None``, otherwise they are what ``ValidImgDetector.valid_img``
        returned.
    """
    status = "Success"
    result = child = female = male = None
    try:
        # Reject empty input up front so the user sees the friendly message.
        if not photo:
            raise ValueError("请正确输入图片")

        if is_url(photo):
            if os.path.exists(TMP_DIR):
                shutil.rmtree(TMP_DIR)

            photo = download_file(photo, f"{TMP_DIR}/download.jpg")

        # NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13.
        if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
            raise ValueError("请正确输入图片")

        # Models are only loaded once the input has been validated.
        result, child, female, male = _get_detector().valid_img(photo)

    except Exception as e:
        # UI boundary: surface any failure as the status text.
        status = f"{e}"

    return status, result, child, female, male


def _result_outputs():
    """Return fresh output components matching ``infer``'s 5-tuple.

    Order: status text, detection image, then the child / female / male
    yes-no boxes. Both tabs call this so their outputs stay in sync with
    ``infer`` and each tab gets its own component instances.
    """
    return [
        gr.Textbox(label=_L("状态栏"), show_copy_button=True),
        gr.Image(
            label=_L("检测结果"),
            type="numpy",
            show_share_button=False,
        ),
        gr.Textbox(label=_L("存在儿童")),
        gr.Textbox(label=_L("存在女性")),
        gr.Textbox(label=_L("存在男性")),
    ]


if __name__ == "__main__":
    # Silence all Python warnings for a cleaner console.
    warnings.filterwarnings("ignore")
    with gr.Blocks() as iface:
        gr.Markdown(_L("# 性别年龄检测器"))
        # Upload tab: the image component passes ``infer`` a local file path.
        with gr.Tab(_L("上传模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label=_L("上传照片"), type="filepath"),
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                flagging_mode="never",
                cache_examples=False,
            )

        # Online tab: the user pastes an image URL, which ``infer`` downloads.
        with gr.Tab(_L("在线模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(
                    label=_L("网络图片链接"),
                    show_copy_button=True,
                ),
                outputs=_result_outputs(),
                flagging_mode="never",
            )

    iface.launch()