from transformers import pipeline, ViTModel, AutoImageProcessor
from PIL import Image
import gradio as gr
import torch
import os
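
# OWL-ViT zero-shot detector localizes "human face" boxes; a plain ViT encoder
# (its pooler output) embeds face crops for similarity matching.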
detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
model = ViTModel.from_pretrained("google/vit-base-patch16-224")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
# Global cache of candidate identities, filled by load_candidates_in_cache().
candidates = []
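
# Detect the single most confident "human face" box in an image and return the crop.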
def extract_face(input_image):
    predictions = detector(
        input_image,
        candidate_labels=["human face"],
    )
    scores = [prediction["score"] for prediction in predictions]
    # Pick the highest-scoring detection; box values are (xmin, ymin, xmax, ymax).
    max_score_box = tuple(predictions[scores.index(max(scores))]["box"].values())
    face_image = input_image.crop(max_score_box)
    return face_image
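
# Load candidate identities from disk. Each subdirectory name is used as the
# identity label, e.g. (illustrative names):
#   candidate_dir/
#       alice/
#           photo1.jpg
#       bob/
#           photo2.png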
def load_candidates(candidate_dir):
    assert os.path.exists(candidate_dir), f"Path candidate_dir {candidate_dir} does not exist."
    candidates = []
    candidate_labels = os.listdir(candidate_dir)
    for candidate_label in candidate_labels:
        image_paths = os.listdir(os.path.join(candidate_dir, candidate_label))
        images = [
            Image.open(os.path.join(candidate_dir, candidate_label, image_path)).convert("RGB")
            for image_path in image_paths
            if image_path.endswith((".jpg", ".png", ".jpeg", ".bmp"))
        ]
        candidates.append(dict(label=candidate_label, images=images))
    return candidates
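
# Run face detection on every stored image and keep the cropped faces per candidate.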
def extract_faces(candidates):
    for candidate in candidates:
        faces = []
        for image in candidate["images"]:
            faces.append(extract_face(image))
        candidate["faces"] = faces
    return candidates
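
# Embed a candidate's images (under the key given by `target`, e.g. "faces") with
# the ViT encoder and average the pooled outputs into one feature per identity.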
def extract_feature(candidates, target):
    for candidate in candidates:
        target_images = candidate[target]
        pixel_values = image_processor(target_images, return_tensors="pt")["pixel_values"]
        with torch.no_grad():
            features = model(pixel_values)["pooler_output"]
        feature = features.mean(0)  # average over all of this candidate's images
        candidate["feature"] = feature
    return candidates
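
# Convenience wrapper: crop faces for every candidate, then compute their features.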
def load_candidates_face_feature(candidates):
    candidates = extract_faces(candidates)
    candidates = extract_feature(candidates, "faces")
    return candidates
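
# Embed a detected face and return the label of the candidate with the highest
# cosine similarity to it.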
def compare_with_candidates(detected_face, candidates):
    pixel_values = image_processor(detected_face, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        detected_feature = model(pixel_values)["pooler_output"].squeeze(0)
    sims = []
    labels = [candidate["label"] for candidate in candidates]
    for candidate in candidates:
        sim = torch.cosine_similarity(detected_feature, candidate["feature"], dim=0).item()
        sims.append(sim)
    return labels[sims.index(max(sims))]
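
# Detect every face in the input image and label each box with its best-matching
# candidate. Returns (image, [(box, label), ...]), the format gr.AnnotatedImage expects.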
def face_recognition(detected_image):
    predictions = detector(
        detected_image,
        candidate_labels=["human face"],
    )
    labels = []
    for p in predictions:
        box = tuple(p["box"].values())
        label = compare_with_candidates(detected_image.crop(box), candidates)
        labels.append((box, label))
    return detected_image, labels
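
# Gradio callback: (re)build the global candidate cache from a directory path.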
def load_candidates_in_cache(candidate_dir):
    global candidates
    candidates = load_candidates(candidate_dir)
    candidates = load_candidates_face_feature(candidates)
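
# Gradio UI: input image and annotated output side by side, a textbox plus Load
# button to build the candidate cache, and a button to run recognition.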
def main():
    with gr.Blocks() as demo:
        with gr.Row():
            detected_image = gr.Image(type="pil", label="detected_image")
            output_image = gr.AnnotatedImage(type="pil", label="output_image")
        with gr.Row():
            candidate_dir = gr.Textbox(label="candidate_dir")
            load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
        btn = gr.Button("Face Recognition", variant="primary")
        load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
        btn.click(fn=face_recognition, inputs=[detected_image], outputs=[output_image])
    demo.launch(debug=True)


if __name__ == "__main__":
    main()