"""Streamlit app for face verification using a Siamese network or HOG + SVM."""
import cv2 | |
from PIL import Image | |
import streamlit as st | |
import tempfile | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from torchvision.models import resnet50 | |
from mtcnn import MTCNN | |
from skimage.feature import hog | |
import joblib | |
import numpy as np | |
class VGGFaceEmbedding(nn.Module):
    """Embed a face image as a 2048-d vector using a truncated ResNet-50."""

    def __init__(self):
        super().__init__()
        # Keep only the convolutional trunk: drop ResNet-50's avgpool and fc head.
        backbone = resnet50(pretrained=True)
        trunk = list(backbone.children())[:-2]
        self.base_model = nn.Sequential(*trunk)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Map a (B, 3, H, W) batch to a (B, 2048) embedding batch."""
        features = self.base_model(x)
        pooled = self.pooling(features)
        return self.flatten(pooled)
class L1Dist(nn.Module):
    """Element-wise L1 (absolute) distance between two embedding tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, input_embedding, validation_embedding):
        """Return |input_embedding - validation_embedding|, same shape as the inputs."""
        difference = input_embedding - validation_embedding
        return difference.abs()
class SiameseNetwork(nn.Module):
    """Siamese face verifier: embed both images, score their L1 distance with an MLP head."""

    def __init__(self):
        super().__init__()
        # Attribute names match the saved checkpoints (state_dict keys).
        self.embedding = VGGFaceEmbedding()
        self.distance = L1Dist()
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image, validation_image):
        """Return a (B, 1) match probability for each image pair."""
        emb_a = self.embedding(input_image)
        emb_b = self.embedding(validation_image)
        hidden = self.fc1(self.distance(emb_a, emb_b))
        return self.sigmoid(self.fc2(hidden))
def preprocess_image_siamese(temp_face_path):
    """Load an image file and return a (3, 224, 224) float tensor in [0, 1]."""
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    rgb_image = Image.open(temp_face_path).convert("RGB")
    return pipeline(rgb_image)
def preprocess_image_svm(img):
    """Resize to 224x224 and convert from BGR (OpenCV order) to grayscale."""
    resized = cv2.resize(img, (224, 224))
    return cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
def extract_hog_features(img):
    """Compute a HOG descriptor (9 orientations, 16x16 cells, 4x4 blocks) for a grayscale image."""
    return hog(
        img,
        orientations=9,
        pixels_per_cell=(16, 16),
        cells_per_block=(4, 4),
    )
def get_face(img):
    """Detect the first face with MTCNN and return its crop, or None when no face is found."""
    detections = MTCNN().detect_faces(img)
    if not detections:
        return None
    x, y, width, height = detections[0]['box']
    # MTCNN may report a negative box origin; take absolute values before cropping.
    x, y = abs(x), abs(y)
    return img[y:y + height, x:x + width]
def verify(image, model, person, validation_image=None, threshold=None):
    """Verify a captured face against a person's stored model and display the verdict.

    Args:
        image: Streamlit uploaded/captured file object with the probe photo.
        model: "Siamese" or "HOG-SVM" — which verifier to run.
        person: Person name; selects the weight files (e.g. ``siamese_theo.pth``).
        validation_image: Second uploaded file (required for "Siamese").
        threshold: Match-probability cutoff (required for "Siamese").

    Side effects: renders the detected face(s) and "Match" / "Not Match" via Streamlit.
    """
    # Persist the upload to disk so OpenCV can read it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_image:
        temp_image.write(image.read())
        temp_image_path = temp_image.name
    image = cv2.imread(temp_image_path)
    face = get_face(image)
    # Bug fix: previously the face was written with cv2.imwrite BEFORE the None
    # check, crashing whenever no face was detected. Guard first.
    if face is None:
        st.write("No face detected in the captured image")
        return
    # tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile is the safe equivalent.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_face:
        temp_face_path = temp_face.name
    cv2.imwrite(temp_face_path, face)
    if model == "Siamese":
        siamese = SiameseNetwork()
        siamese.load_state_dict(torch.load(f'siamese_{person.lower()}.pth'))
        siamese.eval()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as validation_temp_image:
            validation_temp_image.write(validation_image.read())
            validation_temp_image_path = validation_temp_image.name
        validation_image = cv2.imread(validation_temp_image_path)
        validation_face = get_face(validation_image)
        if validation_face is None:
            st.write("No face detected in the validation image")
            return
        st.image([face, validation_face], caption=["Face 1", "Face 2"], width=200)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as validation_temp_face:
            validation_temp_face_path = validation_temp_face.name
        cv2.imwrite(validation_temp_face_path, validation_face)
        face_tensor = preprocess_image_siamese(temp_face_path).unsqueeze(0)
        # Bug fix: compare the cropped validation FACE, not the full validation
        # image (the original mistakenly preprocessed validation_temp_image_path).
        validation_tensor = preprocess_image_siamese(validation_temp_face_path).unsqueeze(0)
        with torch.no_grad():
            probability = siamese(face_tensor, validation_tensor).item()
        pred = 1.0 if probability > threshold else 0.0
        if pred == 1:
            st.write("Match")
        else:
            st.write("Not Match")
    elif model == "HOG-SVM":
        with open(f'./svm_{person.lower()}.pkl', 'rb') as f:
            svm = joblib.load(f)
        with open(f'./pca_{person.lower()}.pkl', 'rb') as f:
            pca = joblib.load(f)
        face = cv2.imread(temp_face_path)
        face = preprocess_image_svm(face)
        st.image(face, caption="Face 1", width=200)
        # Renamed from `hog`, which shadowed the imported skimage.feature.hog function.
        hog_features = extract_hog_features(face)
        hog_pca = pca.transform([hog_features])
        pred = svm.predict(hog_pca)
        if pred == 1:
            st.write("Match")
        else:
            st.write("Not Match")
def main():
    """Streamlit entry point: choose model and person, then verify a camera capture."""
    st.title("Face Verification")
    # Per-person decision thresholds used by the Siamese verifier.
    person_dict = {
        "Theo": 0.542,
        "Deverel": 0.5,
        "Justin": 0.5,
    }
    model = st.selectbox("Select Model", ["Siamese", "HOG-SVM"])
    person = st.selectbox("Select Person", person_dict.keys())
    uploaded_image = None
    if model == "Siamese":
        uploaded_image = st.file_uploader("Upload Validation Image (Siamese)", type=["jpg", "png"])
    enable = st.checkbox("Enable camera")
    captured_image = st.camera_input("Take a picture", disabled=not enable)
    if captured_image:
        if model == "Siamese":
            verify(captured_image, model, person, uploaded_image, person_dict.get(person))
        elif model == "HOG-SVM":
            verify(captured_image, model, person)


if __name__ == "__main__":
    main()