HaohuaLv committed
Commit cd7ae94 · 1 Parent(s): 41448a1

Upload app.py

Files changed (1)
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
+ from transformers import pipeline, ViTModel, AutoImageProcessor
+ from PIL import Image
+ import gradio as gr
+ import torch
+ import os
+
+
+ # Zero-shot face detector (OWL-ViT) and a ViT backbone used as a feature extractor.
+ detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
+ model = ViTModel.from_pretrained("google/vit-base-patch16-224")
+ image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
+
+ # Candidate gallery, populated via the "Load" button.
+ candidates = []
+
+ def extract_face(input_image):
+     predictions = detector(
+         input_image,
+         candidate_labels=["human face"],
+     )
+     # Crop out the highest-scoring "human face" detection; the box dict is
+     # ordered (xmin, ymin, xmax, ymax), which matches PIL's crop box.
+     best_prediction = max(predictions, key=lambda p: p["score"])
+     max_score_box = tuple(best_prediction["box"].values())
+     face_image = input_image.crop(max_score_box)
+     return face_image
+
+
+
+ def load_candidates(candidate_dir):
+     assert os.path.exists(candidate_dir), f"candidate_dir {candidate_dir} does not exist."
+
+     # Each sub-directory is one candidate: its name is the label, and the
+     # images inside are that candidate's reference photos.
+     candidates = []
+     candidate_labels = os.listdir(candidate_dir)
+     for candidate_label in candidate_labels:
+         image_paths = os.listdir(os.path.join(candidate_dir, candidate_label))
+         images = [
+             Image.open(os.path.join(candidate_dir, candidate_label, image_path)).convert("RGB")
+             for image_path in image_paths
+             if image_path.endswith((".jpg", ".png", ".jpeg", ".bmp"))
+         ]
+         candidates.append(dict(label=candidate_label, images=images))
+     return candidates
+
+ def extract_faces(candidates):
+     for candidate in candidates:
+         faces = []
+         for image in candidate["images"]:
+             faces.append(extract_face(image))
+         candidate["faces"] = faces
+     return candidates
+
+ def extract_feature(candidates, target):
+     # Average the ViT pooler outputs over all of a candidate's images to get
+     # a single reference embedding per candidate.
+     for candidate in candidates:
+         target_images = candidate[target]
+         pixel_values = image_processor(target_images, return_tensors="pt")["pixel_values"]
+         with torch.no_grad():
+             features = model(pixel_values)["pooler_output"]
+         feature = features.mean(0)
+         candidate["feature"] = feature
+     return candidates
+
+
+ def load_candidates_face_feature(candidates):
+     candidates = extract_faces(candidates)
+     candidates = extract_feature(candidates, "faces")
+     return candidates
+
+ def compare_with_candidates(detected_face, candidates):
+     pixel_values = image_processor(detected_face, return_tensors="pt")["pixel_values"]
+     with torch.no_grad():
+         detected_feature = model(pixel_values)["pooler_output"].squeeze(0)
+     # Return the label of the candidate whose reference embedding has the
+     # highest cosine similarity to the detected face.
+     sims = []
+     labels = [candidate["label"] for candidate in candidates]
+     for candidate in candidates:
+         sim = torch.cosine_similarity(detected_feature, candidate["feature"], dim=0).item()
+         sims.append(sim)
+     return labels[sims.index(max(sims))]
+
+ def face_recognition(detected_image):
+     predictions = detector(
+         detected_image,
+         candidate_labels=["human face"],
+     )
+     # Label every detected face and return (image, [(box, label), ...]),
+     # the annotation format gr.AnnotatedImage expects.
+     labels = []
+     for p in predictions:
+         box = tuple(p["box"].values())
+         label = compare_with_candidates(detected_image.crop(box), candidates)
+         labels.append((box, label))
+
+     return detected_image, labels
+
+ def load_candidates_in_cache(candidate_dir):
+     # Build the candidate gallery once and cache it in the module-level global.
+     global candidates
+     candidates = load_candidates(candidate_dir)
+     candidates = load_candidates_face_feature(candidates)
+
+
+ def main():
+     with gr.Blocks() as demo:
+         with gr.Row():
+             detected_image = gr.Image(type="pil", label="detected_image")
+             output_image = gr.AnnotatedImage(label="output_image")
+
+         with gr.Row():
+             candidate_dir = gr.Textbox(label="candidate_dir")
+             load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
+         btn = gr.Button("Face Recognition", variant="primary")
+         load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
+         btn.click(fn=face_recognition, inputs=[detected_image], outputs=[output_image])
+
+     demo.launch(server_port=7862)
+
+ if __name__ == "__main__":
+     main()
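
The app assumes candidate_dir contains one sub-folder per person: the folder name is the label returned for a match, and the images inside are that person's reference photos. A minimal sketch of the expected layout and a headless call sequence (folder and file names here are placeholders, not part of the commit):

    candidates/
        alice/
            1.jpg
            2.png
        bob/
            1.jpg

    # Hypothetical usage without the Gradio UI, assuming this file is saved as app.py
    # and a "candidates/" directory like the one above sits next to it.
    from PIL import Image
    import app

    app.load_candidates_in_cache("candidates")   # build the reference gallery
    image, boxes_and_labels = app.face_recognition(
        Image.open("group_photo.jpg").convert("RGB")   # any test image containing faces
    )
    print(boxes_and_labels)   # [((xmin, ymin, xmax, ymax), "alice"), ...]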