File size: 3,810 Bytes
cd7ae94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7dbdd4
cd7ae94
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from transformers import pipeline, ViTModel, AutoImageProcessor
from PIL import Image
import gradio as gr
import torch
import os


# Zero-shot face detector: OWL-ViT localizes boxes for the text query "human face".
detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
# ViT backbone and its matching preprocessor, used to embed cropped faces for identity matching.
model = ViTModel.from_pretrained("google/vit-base-patch16-224")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Module-level cache of candidate identities; populated by load_candidates_in_cache().
candidates = []

def extract_face(input_image):
    """Detect the most confident "human face" region in an image and crop it.

    Args:
        input_image: A ``PIL.Image.Image`` to search for a face.

    Returns:
        A ``PIL.Image.Image`` cropped to the highest-scoring face box.

    Raises:
        ValueError: If the detector returns no predictions (the original
            code also raised ``ValueError`` here, via ``max([])``).
    """
    predictions = detector(
        input_image,
        candidate_labels=["human face"],
    )
    if not predictions:
        raise ValueError("No human face detected in the input image.")
    # BUG FIX: the original indexed with predictions[scores == max(scores)],
    # where `scores == max(scores)` compares a list to a float and is always
    # False (i.e. index 0) — so the FIRST prediction was returned regardless
    # of score. Select the highest-scoring prediction explicitly.
    best = max(predictions, key=lambda p: p["score"])
    # Pipeline box dicts are ordered (xmin, ymin, xmax, ymax), which matches
    # PIL's crop() argument order — presumably; TODO confirm for this model.
    box = tuple(best["box"].values())
    return input_image.crop(box)



def load_candidates(candidate_dir):
    """Load candidate identities from a directory tree.

    Expected layout: ``candidate_dir/<label>/<image files>`` — one
    subdirectory per person, containing that person's photos.

    Args:
        candidate_dir: Path to the root directory of candidate images.

    Returns:
        A list of dicts ``{"label": <subdir name>, "images": [RGB PIL images]}``,
        sorted by label for deterministic ordering.

    Raises:
        AssertionError: If ``candidate_dir`` does not exist.
    """
    assert os.path.exists(candidate_dir), f"Path candidate_dir {candidate_dir} does not exist."

    candidates = []
    for candidate_label in sorted(os.listdir(candidate_dir)):
        label_dir = os.path.join(candidate_dir, candidate_label)
        # FIX: skip stray top-level files (e.g. README, .DS_Store) — the
        # original called os.listdir() on them and crashed with
        # NotADirectoryError.
        if not os.path.isdir(label_dir):
            continue
        images = [
            Image.open(os.path.join(label_dir, name)).convert("RGB")
            for name in sorted(os.listdir(label_dir))
            # FIX: case-insensitive suffix match so ".JPG"/".PNG" files
            # are not silently skipped.
            if name.lower().endswith((".jpg", ".png", ".jpeg", ".bmp"))
        ]
        candidates.append(dict(label=candidate_label, images=images))
    return candidates

def extract_faces(candidates):
    """Attach cropped face images to every candidate.

    Runs ``extract_face`` on each image in ``candidate["images"]`` and
    stores the crops under the ``"faces"`` key. Mutates the dicts in
    place and returns the same list.
    """
    for entry in candidates:
        entry["faces"] = [extract_face(img) for img in entry["images"]]
    return candidates

def extract_featrue(candidates, target):
    """Compute one averaged ViT embedding per candidate.

    Args:
        candidates: List of candidate dicts.
        target: Key under which each candidate stores the list of PIL
            images to embed (e.g. ``"faces"`` or ``"images"``).

    Returns:
        The same list, with each dict gaining a ``"feature"`` tensor:
        the mean of the model's pooler outputs over that candidate's images.
    """
    for entry in candidates:
        batch = image_processor(entry[target], return_tensors="pt")
        pooled = model(batch["pixel_values"])["pooler_output"]
        # Average across the candidate's images to get a single descriptor.
        entry["feature"] = pooled.mean(0)
    return candidates


def load_candidates_face_feature(candidates):
    """Crop faces from every candidate's images, then embed the crops.

    Convenience composition of ``extract_faces`` and ``extract_featrue``.
    """
    return extract_featrue(extract_faces(candidates), "faces")

def compare_with_candidates(detectd_face, candidates):
    """Return the label of the candidate most similar to a detected face.

    Embeds ``detectd_face`` with the ViT model and scores it against each
    candidate's precomputed ``"feature"`` via cosine similarity; the
    highest-scoring label wins (first one in case of a tie).
    """
    batch = image_processor(detectd_face, return_tensors="pt")
    query = model(batch["pixel_values"])["pooler_output"].squeeze(0)
    similarities = [
        torch.cosine_similarity(query, entry["feature"], dim=0).item()
        for entry in candidates
    ]
    # argmax — like sims.index(max(sims)), picks the first maximum.
    best = max(range(len(similarities)), key=similarities.__getitem__)
    return candidates[best]["label"]

def face_recognition(detected_image):
    """Detect every face in an image and identify each one.

    Returns an ``(image, annotations)`` pair suitable for
    ``gr.AnnotatedImage``: the unmodified input image plus a list of
    ``(box, label)`` tuples, one per detected face, where ``box`` is the
    detector's box values as a tuple.
    """
    predictions = detector(
        detected_image,
        candidate_labels=["human face"],
    )
    annotations = []
    for prediction in predictions:
        box = tuple(prediction["box"].values())
        face_crop = detected_image.crop(box)
        annotations.append((box, compare_with_candidates(face_crop, candidates)))

    return detected_image, annotations

def load_candidates_in_cache(candidate_dir):
    """Load candidates from disk, precompute their face features, and
    store the result in the module-level ``candidates`` cache."""
    global candidates
    candidates = load_candidates_face_feature(load_candidates(candidate_dir))


def main():
    """Assemble and launch the Gradio face-recognition UI."""
    with gr.Blocks() as app:
        with gr.Row():
            input_img = gr.Image(type="pil", label="detected_image")
            annotated = gr.AnnotatedImage(type="pil", label="output_image")

        with gr.Row():
            dir_box = gr.Textbox(label="candidate_dir")
            load_btn = gr.Button("Load", variant="secondary", size="sm")
        run_btn = gr.Button("Face Recognition", variant="primary")

        # Wire events: loading the candidate cache has no visible output;
        # the recognition button feeds the annotated-image component.
        load_btn.click(fn=load_candidates_in_cache, inputs=[dir_box])
        run_btn.click(fn=face_recognition, inputs=[input_img], outputs=[annotated])

    app.launch(debug=True)

if __name__ == "__main__":
    main()