import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel, pipeline

# Zero-shot detector used to localise faces with the text prompt in FACE_LABELS.
detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
# ViT backbone; its ``pooler_output`` is used as the face embedding.
model = ViTModel.from_pretrained("google/vit-base-patch16-224")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

FACE_LABELS = ["human face"]
IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg", ".bmp")

# Reference identities, filled in by ``load_candidates_in_cache``. Each entry is
# a dict with "label", "images", "faces" and "feature" keys.
candidates = []


def _box_to_tuple(box):
    """Convert a detector box dict into a PIL ``(left, upper, right, lower)`` crop box."""
    # Index by key instead of relying on the dict's insertion order.
    return (box["xmin"], box["ymin"], box["xmax"], box["ymax"])


def _crop_best_face(image):
    """Return a crop of the highest-scoring face in *image*, or ``None`` if none is found."""
    predictions = detector(image, candidate_labels=FACE_LABELS)
    if not predictions:
        return None
    # Select the top-scoring box explicitly rather than depending on output order.
    best = max(predictions, key=lambda prediction: prediction["score"])
    return image.crop(_box_to_tuple(best["box"]))


def _embed(images):
    """Return the ViT ``pooler_output`` for *images* (a PIL image or a list of them)."""
    pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        return model(pixel_values)["pooler_output"]


def extract_face(input_image):
    """Crop the most confident "human face" detection from *input_image*.

    Raises:
        ValueError: if no face is detected in the image.
    """
    face_image = _crop_best_face(input_image)
    if face_image is None:
        raise ValueError("No human face detected in the input image.")
    return face_image


def load_candidates(candidate_dir):
    """Load reference images laid out as ``candidate_dir/<label>/<image file>``.

    Each sub-directory name is used as the identity label. Non-directory entries
    and files without a supported image extension (case-insensitive) are ignored.

    Returns:
        list of ``{"label": str, "images": [PIL RGB images]}`` dicts, one per label.

    Raises:
        FileNotFoundError: if *candidate_dir* is not an existing directory.
    """
    if not os.path.isdir(candidate_dir):
        raise FileNotFoundError(f"Path candidate_dir {candidate_dir} is not an existing directory.")
    loaded = []
    for candidate_label in sorted(os.listdir(candidate_dir)):
        label_dir = os.path.join(candidate_dir, candidate_label)
        if not os.path.isdir(label_dir):
            continue  # stray files (e.g. .DS_Store) are not identities
        images = []
        for image_path in sorted(os.listdir(label_dir)):
            if not image_path.lower().endswith(IMAGE_EXTENSIONS):
                continue
            # Close the underlying file once the RGB copy has been made.
            with Image.open(os.path.join(label_dir, image_path)) as image:
                images.append(image.convert("RGB"))
        loaded.append(dict(label=candidate_label, images=images))
    return loaded


def extract_faces(candidates):
    """Attach a ``"faces"`` list of face crops to every candidate (mutates in place).

    Images in which no face is detected are skipped instead of aborting the load.
    """
    for candidate in candidates:
        crops = (_crop_best_face(image) for image in candidate["images"])
        candidate["faces"] = [face for face in crops if face is not None]
    return candidates


def extract_feature(candidates, target):
    """Store the mean ViT embedding of ``candidate[target]`` as ``candidate["feature"]``.

    Candidates whose *target* list is empty get ``feature = None`` and are
    ignored by ``compare_with_candidates``.
    """
    for candidate in candidates:
        target_images = candidate[target]
        if not target_images:
            candidate["feature"] = None
            continue
        candidate["feature"] = _embed(target_images).mean(0)
    return candidates


# Backward-compatible alias for the original (misspelled) name.
extract_featrue = extract_feature


def load_candidates_face_feature(candidates):
    """Detect faces in every candidate's images and compute their mean face embedding."""
    candidates = extract_faces(candidates)
    return extract_feature(candidates, "faces")


def compare_with_candidates(detectd_face, candidates):
    """Return the label of the candidate whose feature is most cosine-similar to *detectd_face*.

    Ties resolve to the earliest candidate in the list.

    Raises:
        ValueError: if no candidate has a usable feature.
    """
    usable = [candidate for candidate in candidates if candidate.get("feature") is not None]
    if not usable:
        raise ValueError("No candidate features available; load candidates first.")
    detectd_feature = _embed(detectd_face).squeeze(0)
    best = max(
        usable,
        key=lambda candidate: torch.cosine_similarity(
            detectd_feature, candidate["feature"], dim=0
        ).item(),
    )
    return best["label"]


def face_recognition(detected_image):
    """Detect every face in *detected_image* and label it with the closest candidate.

    Returns:
        ``(detected_image, [(box, label), ...])`` for ``gr.AnnotatedImage``.

    Raises:
        gr.Error: if no image was provided or no candidates have been loaded.
    """
    if detected_image is None:
        raise gr.Error("Please provide an image first.")
    if not candidates:
        raise gr.Error("No candidates loaded; enter candidate_dir and click Load first.")
    predictions = detector(detected_image, candidate_labels=FACE_LABELS)
    labels = []
    for prediction in predictions:
        box = _box_to_tuple(prediction["box"])
        label = compare_with_candidates(detected_image.crop(box), candidates)
        labels.append((box, label))
    return detected_image, labels


def load_candidates_in_cache(candidate_dir):
    """Load candidates from *candidate_dir*, compute their face features, and cache them globally."""
    global candidates
    candidates = load_candidates(candidate_dir)
    candidates = load_candidates_face_feature(candidates)


def main():
    """Build and launch the Gradio UI."""
    with gr.Blocks() as demo:
        with gr.Row():
            detected_image = gr.Image(type="pil", label="detected_image")
            output_image = gr.AnnotatedImage(type="pil", label="output_image")
        with gr.Row():
            candidate_dir = gr.Textbox(label="candidate_dir")
            load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
        btn = gr.Button("Face Recognition", variant="primary")
        load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
        btn.click(fn=face_recognition, inputs=[detected_image], outputs=[output_image])
    demo.launch(debug=True)


if __name__ == "__main__":
    main()