"""Streamlit demo that verifies a captured face against a known person.

Two back-ends are offered: a Siamese network that compares the captured face
with an uploaded validation image, and a HOG + PCA + SVM classifier that
works on the captured face alone. Per-person model files are loaded from the
current working directory.
"""
import contextlib
import functools
import os
import tempfile

import cv2
import joblib
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from mtcnn import MTCNN
from PIL import Image
from skimage.feature import hog
from torchvision import transforms
from torchvision.models import resnet50


class VGGFaceEmbedding(nn.Module):
    """Embedding backbone: ResNet-50 without its last two children, pooled flat.

    Despite the name, the backbone built here is torchvision's ``resnet50``.
    Attribute names are part of the checkpoint's ``state_dict`` keys and must
    not be renamed.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet50(pretrained=True)
        # Drop the final two children so only the convolutional trunk remains.
        self.base_model = nn.Sequential(*list(backbone.children())[:-2])
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return one flat embedding vector per input image."""
        x = self.base_model(x)
        x = self.pooling(x)
        return self.flatten(x)


class L1Dist(nn.Module):
    """Element-wise absolute difference between two embeddings."""

    def __init__(self):
        super().__init__()

    def forward(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|``."""
        return torch.abs(input_embedding - validation_embedding)


class SiameseNetwork(nn.Module):
    """Siamese verifier: shared embedding -> L1 distance -> 2 linear layers -> sigmoid."""

    def __init__(self):
        super().__init__()
        self.embedding = VGGFaceEmbedding()
        self.distance = L1Dist()
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image, validation_image):
        """Return a sigmoid score in (0, 1) for each image pair in the batch."""
        input_embedding = self.embedding(input_image)
        validation_embedding = self.embedding(validation_image)
        distances = self.distance(input_embedding, validation_embedding)
        # NOTE: there is no activation between fc1 and fc2; this mirrors the
        # architecture the saved checkpoints were created with.
        x = self.fc1(distances)
        x = self.fc2(x)
        return self.sigmoid(x)


def preprocess_image_siamese(temp_face_path):
    """Load the image at *temp_face_path* as a 224x224 RGB tensor."""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    img = Image.open(temp_face_path).convert("RGB")
    return transform(img)


def preprocess_image_svm(img):
    """Resize a BGR image to 224x224 and convert it to grayscale."""
    img = cv2.resize(img, (224, 224))
    return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)


def extract_hog_features(img):
    """Return the HOG descriptor of *img* (9 orientations, 16x16 cells, 4x4 blocks)."""
    return hog(img, orientations=9, pixels_per_cell=(16, 16), cells_per_block=(4, 4))


@functools.lru_cache(maxsize=1)
def _get_detector():
    """Build the MTCNN detector once and reuse it across calls."""
    return MTCNN()


def get_face(img):
    """Return the crop of the first face MTCNN finds in *img*, or ``None``.

    ``None`` is returned when no face is detected or the detected box yields
    an empty crop.
    """
    faces = _get_detector().detect_faces(img)
    if not faces:
        return None

    x1, y1, w, h = faces[0]['box']
    # Presumably guards against slightly negative box origins from the
    # detector; kept as-is so crops match the original pipeline.
    x1, y1 = abs(x1), abs(y1)
    x2, y2 = x1 + w, y1 + h
    crop = img[y1:y2, x1:x2]
    if crop.size == 0:
        return None
    return crop


def _save_upload(uploaded_file, temp_paths):
    """Write an uploaded file-like object to a temp ``.jpg`` and return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        # Register before writing so the file is cleaned up even if the write fails.
        temp_paths.append(tmp.name)
        tmp.write(uploaded_file.read())
        return tmp.name


def _write_face(face, temp_paths):
    """Write a face crop to a fresh temp ``.jpg`` and return its path."""
    # mkstemp creates the file atomically, unlike the deprecated mktemp.
    fd, path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    temp_paths.append(path)
    cv2.imwrite(path, face)
    return path


def _detect_uploaded_face(uploaded_file, temp_paths):
    """Decode an uploaded image and return its first face crop, or ``None``."""
    image = cv2.imread(_save_upload(uploaded_file, temp_paths))
    if image is None:  # the upload could not be decoded as an image
        return None
    return get_face(image)


def _verify_siamese(face, person, validation_image, threshold, temp_paths):
    """Compare *face* with the face in *validation_image* using the Siamese model."""
    if validation_image is None:
        st.warning("Please upload a validation image for the Siamese model.")
        return

    validation_face = _detect_uploaded_face(validation_image, temp_paths)
    if validation_face is None:
        st.warning("No face detected in the validation image.")
        return

    # Crops come from cv2.imread, so they are BGR; tell Streamlit so colours
    # are not swapped on display.
    st.image(
        [face, validation_face],
        caption=["Face 1", "Face 2"],
        width=200,
        channels="BGR",
    )

    face_path = _write_face(face, temp_paths)
    validation_face_path = _write_face(validation_face, temp_paths)

    siamese = SiameseNetwork()
    # Inference below runs on CPU tensors, so map the checkpoint to CPU too.
    state = torch.load(f'siamese_{person.lower()}.pth', map_location="cpu")
    siamese.load_state_dict(state)
    siamese.eval()

    face_batch = preprocess_image_siamese(face_path).unsqueeze(0)
    # Use the cropped validation face, not the full uploaded image.
    validation_batch = preprocess_image_siamese(validation_face_path).unsqueeze(0)

    with torch.no_grad():
        probability = siamese(face_batch, validation_batch).item()

    if threshold is None:
        threshold = 0.5
    st.write("Match" if probability > threshold else "Not Match")


def _verify_hog_svm(face, person, temp_paths):
    """Classify *face* with the person's PCA + SVM models on HOG features."""
    # joblib.load unpickles; these are local model files, never user uploads.
    with open(f'./svm_{person.lower()}.pkl', 'rb') as f:
        svm = joblib.load(f)
    with open(f'./pca_{person.lower()}.pkl', 'rb') as f:
        pca = joblib.load(f)

    # Round-trip through JPEG, as the original pipeline did, so the features
    # are computed on the same kind of input.
    face_path = _write_face(face, temp_paths)
    gray = preprocess_image_svm(cv2.imread(face_path))
    st.image(gray, caption="Face 1", width=200)

    features = extract_hog_features(gray)
    reduced = pca.transform([features])
    pred = svm.predict(reduced)
    st.write("Match" if pred[0] == 1 else "Not Match")


def verify(image, model, person, validation_image=None, threshold=None):
    """Run face verification and report the result in the Streamlit page.

    Args:
        image: File-like object holding the captured image.
        model: ``"Siamese"`` or ``"HOG-SVM"``; any other value does nothing
            beyond face detection.
        person: Person name; lower-cased to locate the model files.
        validation_image: File-like reference image, required for
            ``"Siamese"``.
        threshold: Siamese score above which the pair is a match; defaults to
            0.5 when ``None``.

    Writes "Match" / "Not Match" to the page, or a warning when an input is
    missing or no face can be detected. Temporary files are always removed.
    """
    temp_paths = []
    try:
        face = _detect_uploaded_face(image, temp_paths)
        if face is None:
            st.warning("No face detected in the captured image.")
            return

        if model == "Siamese":
            _verify_siamese(face, person, validation_image, threshold, temp_paths)
        elif model == "HOG-SVM":
            _verify_hog_svm(face, person, temp_paths)
    finally:
        # Best-effort cleanup: a file that is already gone is not an error.
        for path in temp_paths:
            with contextlib.suppress(OSError):
                os.remove(path)


def main():
    """Render the Streamlit UI and trigger verification on a captured photo."""
    st.title("Face Verification")

    # Per-person decision threshold for the Siamese model.
    person_dict = {
        "Theo": 0.542,
        "Deverel": 0.5,
        "Justin": 0.5,
    }

    model = st.selectbox("Select Model", ["Siamese", "HOG-SVM"])
    person = st.selectbox("Select Person", list(person_dict))

    uploaded_image = None
    if model == "Siamese":
        uploaded_image = st.file_uploader(
            "Upload Validation Image (Siamese)", type=["jpg", "png"]
        )

    enable = st.checkbox("Enable camera")
    captured_image = st.camera_input("Take a picture", disabled=not enable)

    if not captured_image:
        return

    if model == "Siamese":
        verify(captured_image, model, person, uploaded_image, person_dict.get(person))
    elif model == "HOG-SVM":
        verify(captured_image, model, person)


if __name__ == "__main__":
    main()