import torch
from torch import nn
import torchvision.transforms as transforms
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from facenet_pytorch import MTCNN
from transformers import ViTImageProcessor, ViTModel
import pickle
import time


# Define the ViT wrapper: expose the pooled output of the base ViT as the face embedding
class ViT(nn.Module):
    def __init__(self, base_model):
        super(ViT, self).__init__()
        self.base_model = base_model

    def forward(self, x):
        x = self.base_model(x).pooler_output
        return x


# Check if CUDA is available and pick the device up front so the checkpoint
# can be loaded onto it directly
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model and processor (the processor is kept for reference only;
# preprocessing is handled by the torchvision transform defined below)
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
model = ViT(base_model)
model.load_state_dict(torch.load('faceViT6.pth', map_location=device))

# Set the model to evaluation mode and move it to the selected device
model.eval()
model.to(device)

# Initialize MTCNN for face detection
mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)

# Define the preprocessing transformation applied to each aligned face
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load the database of embeddings: a dict mapping person name -> list of embeddings
with open('face_database_me.pkl', 'rb') as f:
    database = pickle.load(f)


def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between two embeddings as a Python float."""
    similarity = torch.nn.functional.cosine_similarity(
        embedding1.flatten().unsqueeze(0), embedding2.flatten().unsqueeze(0)
    )
    return similarity.item()


def compare_embeddings(embedding, database, threshold=0.9):
    """Return the best-matching identity in the database, if any similarity exceeds the threshold."""
    best_match = None
    best_similarity = threshold
    for name, db_embeddings in database.items():
        for db_embedding in db_embeddings:
            db_embedding = torch.tensor(db_embedding).to(device)
            similarity = cosine_similarity(embedding, db_embedding)
            if similarity > best_similarity:
                best_match = name
                best_similarity = similarity
    if best_match is not None:
        return best_match, best_similarity
    return None, None


def align_faces(frame):
    """Detect faces with MTCNN and return (list of aligned PIL faces, bounding boxes)."""
    # Convert the frame to a PIL image if it's a numpy array
    if isinstance(frame, np.ndarray):
        frame = Image.fromarray(frame)
    boxes, _ = mtcnn.detect(frame)
    aligned_faces = []
    if boxes is not None:
        faces = mtcnn(frame)
        if faces is not None:
            for face in faces:
                # Convert the face tensor to a PIL image
                face = transforms.ToPILImage()(face)
                aligned_faces.append(face)
    return aligned_faces, boxes


def draw_annotations(frame, detections, names=None):
    """Draw bounding boxes and optional name labels onto the frame."""
    if detections is None:
        return frame
    if names is None:
        names = ["Unknown"] * len(detections)
    for i, detection in enumerate(detections):
        x1, y1, x2, y2 = map(int, detection)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        if names[i]:
            cv2.putText(frame, names[i], (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
    return frame


def process_image(image):
    """Detect, embed, and identify every face in the uploaded image."""
    start_time = time.time()
    frame = np.array(image)
    aligned_faces, boxes = align_faces(frame)
    names = []
    if aligned_faces:  # align_faces always returns a list, so check for emptiness rather than None
        for face in aligned_faces:
            face = transform(face).unsqueeze(0).to(device)
            with torch.no_grad():
                embedding = model(face)
            name, similarity = compare_embeddings(embedding, database)
            if name is not None:
                names.append(f"{name} ({similarity:.2f})")
            else:
                names.append("Unknown")
        annotated_image = draw_annotations(frame, boxes, names)
        result = "Face recognition complete."
    else:
        annotated_image = frame
        result = "No faces detected."
    end_time = time.time()
    inference_time = end_time - start_time
    result += f" Inference time: {inference_time:.2f} seconds"
    return annotated_image, result


# Create the Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),  # the input type must match what process_image expects
    outputs=[gr.Image(type="numpy"), gr.Textbox()],
    title="Face Detection and Recognition with MTCNN and ViT",
    description="Upload an image and the model will detect and recognize faces in it."
)

# Launch the interface
iface.launch(share=True, debug=True)
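

# --- Optional helper: building the embeddings database -----------------------
# The app above expects 'face_database_me.pkl' to already exist as a dict that
# maps each person's name to a list of embeddings, but nothing in this script
# shows how that file is produced. The function below is only a minimal
# enrollment sketch: it reuses the mtcnn/model/transform/device objects defined
# above and assumes a hypothetical folder layout of one sub-directory of photos
# per person. It is not called anywhere (and sits after the blocking launch()
# call), so run it separately, e.g. from a REPL or a one-off script, before
# starting the UI.
import os


def build_face_database(image_dir, output_path='face_database_me.pkl'):
    """Embed every enrollment photo under image_dir/<person>/ and pickle the result."""
    db = {}
    for person in sorted(os.listdir(image_dir)):
        person_dir = os.path.join(image_dir, person)
        if not os.path.isdir(person_dir):
            continue
        embeddings = []
        for filename in os.listdir(person_dir):
            image = Image.open(os.path.join(person_dir, filename)).convert('RGB')
            faces, _ = align_faces(image)
            if not faces:
                continue  # skip photos where MTCNN finds no face
            face = transform(faces[0]).unsqueeze(0).to(device)
            with torch.no_grad():
                embedding = model(face)
            embeddings.append(embedding.squeeze(0).cpu().numpy())
        if embeddings:
            db[person] = embeddings
    with open(output_path, 'wb') as f:
        pickle.dump(db, f)
    return db

# Example (hypothetical path): build_face_database('enrollment_photos')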