from transformers import pipeline, ViTModel, AutoImageProcessor
from PIL import Image
import gradio as gr
import torch
import os

# OWL-ViT zero-shot detector finds face boxes; a plain ViT encoder turns each
# crop into a feature vector for cosine-similarity matching.
detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
model = ViTModel.from_pretrained("google/vit-base-patch16-224")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Module-level cache of enrolled candidates, filled by the "Load" button.
candidates = []
def extract_face(input_image):
    predictions = detector(
        input_image,
        candidate_labels=["human face"],
    )
    # Keep only the highest-scoring detection.
    scores = [prediction["score"] for prediction in predictions]
    best = predictions[scores.index(max(scores))]
    # The box dict is {"xmin", "ymin", "xmax", "ymax"}, matching PIL's
    # (left, upper, right, lower) crop order.
    max_score_box = tuple(best["box"].values())
    face_image = input_image.crop(max_score_box)
    return face_image
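# Example usage (the file name is hypothetical): crop the most confident face
# out of a photo.
#   face = extract_face(Image.open("group_photo.jpg").convert("RGB"))
#   face.save("face.jpg")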
def load_candidates(candidate_dir):
    assert os.path.exists(candidate_dir), f"candidate_dir {candidate_dir} does not exist."
    candidates = []
    # One sub-directory per person: its name is the label, its files the photos.
    candidate_labels = os.listdir(candidate_dir)
    for candidate_label in candidate_labels:
        image_paths = os.listdir(os.path.join(candidate_dir, candidate_label))
        images = [
            Image.open(os.path.join(candidate_dir, candidate_label, image_path)).convert("RGB")
            for image_path in image_paths
            if image_path.endswith((".jpg", ".png", ".jpeg", ".bmp"))
        ]
        candidates.append(dict(label=candidate_label, images=images))
    return candidates
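# Expected enrollment layout (the names below are hypothetical); each
# sub-directory name becomes a recognition label:
#   candidate_dir/
#       alice/  a.jpg  b.png
#       bob/    c.jpg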
def extract_faces(candidates):
    # Crop the best face out of every enrollment photo.
    for candidate in candidates:
        faces = []
        for image in candidate["images"]:
            faces.append(extract_face(image))
        candidate["faces"] = faces
    return candidates
def extract_feature(candidates, target):
    # Average the ViT pooler_output over all of a candidate's images ("images"
    # or cropped "faces") to get one prototype vector per person.
    for candidate in candidates:
        target_images = candidate[target]
        pixel_values = image_processor(target_images, return_tensors="pt")["pixel_values"]
        with torch.no_grad():
            features = model(pixel_values)["pooler_output"]
        feature = features.mean(0)
        candidate["feature"] = feature
    return candidates
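# Note: for google/vit-base-patch16-224 the pooler_output has shape
# (batch, 768), so each candidate is reduced to a single 768-dim mean prototype.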
def load_candidates_face_feature(candidates):
    candidates = extract_faces(candidates)
    candidates = extract_feature(candidates, "faces")
    return candidates
def compare_with_candidates(detected_face, candidates):
    pixel_values = image_processor(detected_face, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        detected_feature = model(pixel_values)["pooler_output"].squeeze(0)
    # Pick the enrolled candidate with the highest cosine similarity.
    sims = []
    labels = [candidate["label"] for candidate in candidates]
    for candidate in candidates:
        sim = torch.cosine_similarity(detected_feature, candidate["feature"], dim=0).item()
        sims.append(sim)
    return labels[sims.index(max(sims))]
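# compare_with_candidates always returns the nearest label, even when the best
# similarity is low. A minimal sketch of an "unknown" gate; the helper name and
# the 0.6 threshold are assumptions, not part of the original app:
def compare_with_threshold(detected_face, candidates, min_sim=0.6):
    pixel_values = image_processor(detected_face, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        feature = model(pixel_values)["pooler_output"].squeeze(0)
    best_label, best_sim = "unknown", min_sim
    for candidate in candidates:
        sim = torch.cosine_similarity(feature, candidate["feature"], dim=0).item()
        if sim > best_sim:
            best_label, best_sim = candidate["label"], sim
    return best_label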
def face_recognition(detected_image):
    predictions = detector(
        detected_image,
        candidate_labels=["human face"],
    )
    labels = []
    for p in predictions:
        box = tuple(p["box"].values())
        label = compare_with_candidates(detected_image.crop(box), candidates)
        labels.append((box, label))
    return detected_image, labels
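# The (image, [(box, label), ...]) return value is the tuple format that
# gr.AnnotatedImage accepts, with each box as (x1, y1, x2, y2) in pixels.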
def load_candidates_in_cache(candidate_dir):
    global candidates
    candidates = load_candidates(candidate_dir)
    candidates = load_candidates_face_feature(candidates)
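# face_recognition reads the module-level `candidates` cache, so "Load" must be
# clicked to enroll candidates before recognition is run.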
def main():
    with gr.Blocks() as demo:
        with gr.Row():
            detected_image = gr.Image(type="pil", label="detected_image")
            # gr.AnnotatedImage takes no `type` kwarg; it renders the
            # (image, annotations) tuple returned by face_recognition.
            output_image = gr.AnnotatedImage(label="output_image")
        with gr.Row():
            candidate_dir = gr.Textbox(label="candidate_dir")
            load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
        btn = gr.Button("Face Recognition", variant="primary")
        load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
        btn.click(fn=face_recognition, inputs=[detected_image], outputs=[output_image])
    demo.launch(debug=True)


if __name__ == "__main__":
    main()
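# To try the demo locally (assuming this file is saved as app.py; exact package
# versions are an assumption, any recent ones should work):
#   pip install transformers torch gradio pillow
#   python app.py
# Then enter a candidate directory path, click "Load", upload an image, and
# click "Face Recognition".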