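# Streamlit face verification app: detect a face with MTCNN, then verify it
# against a chosen person either with a Siamese network (ResNet-50 embeddings
# compared via L1 distance) or with a HOG + PCA + SVM pipeline loaded from
# pre-trained model files.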
import cv2
from PIL import Image
import streamlit as st
import tempfile
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet50
from mtcnn import MTCNN
from skimage.feature import hog
import joblib
import numpy as np
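
# Embedding backbone: ResNet-50 pretrained on ImageNet, truncated before its
# classification head and globally average-pooled to a 2048-d feature vector.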
class VGGFaceEmbedding(nn.Module):
    def __init__(self):
        super(VGGFaceEmbedding, self).__init__()
        self.base_model = resnet50(pretrained=True)
        self.base_model = nn.Sequential(*list(self.base_model.children())[:-2])
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.base_model(x)
        x = self.pooling(x)
        x = self.flatten(x)
        return x
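
# Element-wise L1 (absolute) distance between two embeddings.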
class L1Dist(nn.Module):
    def __init__(self):
        super(L1Dist, self).__init__()

    def forward(self, input_embedding, validation_embedding):
        return torch.abs(input_embedding - validation_embedding)
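
# Siamese verification head: embeds both images, takes their L1 distance, and
# maps it through two fully connected layers and a sigmoid to a match probability.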
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.embedding = VGGFaceEmbedding()
        self.distance = L1Dist()
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image, validation_image):
        input_embedding = self.embedding(input_image)
        validation_embedding = self.embedding(validation_image)
        distances = self.distance(input_embedding, validation_embedding)
        x = self.fc1(distances)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x
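
# Load a cropped face from disk and convert it to a 224x224 RGB tensor for the Siamese model.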
def preprocess_image_siamese(temp_face_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    img = Image.open(temp_face_path).convert("RGB")
    return transform(img)
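
# Preprocessing for the HOG-SVM pipeline: resize to 224x224 and convert to grayscale,
# then extract HOG descriptors (9 orientations, 16x16 cells, 4x4 blocks).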
def preprocess_image_svm(img):
    img = cv2.resize(img, (224, 224))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img


def extract_hog_features(img):
    hog_features = hog(img, orientations=9, pixels_per_cell=(16, 16), cells_per_block=(4, 4))
    return hog_features
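
# Detect faces with MTCNN and return the crop of the first detection (or None if no face is found).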
def get_face(img):
    detector = MTCNN()
    faces = detector.detect_faces(img)
    if faces:
        x1, y1, w, h = faces[0]['box']
        x1, y1 = abs(x1), abs(y1)
        x2, y2 = x1 + w, y1 + h
        return img[y1:y2, x1:x2]
    return None
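
# Verify the captured image against the selected person with the chosen model.
# For "Siamese", the captured face is compared with the uploaded validation face and
# the match probability is thresholded with the per-person threshold; for "HOG-SVM",
# the pre-trained PCA and SVM artifacts for that person are used.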
def verify(image, model, person, validation_image=None, threshold=None):
    # Persist the captured image to disk so OpenCV can read it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_image:
        temp_image.write(image.read())
        temp_image_path = temp_image.name

    image = cv2.imread(temp_image_path)
    face = get_face(image)

    if face is not None:
        # Only write the crop once we know a face was actually detected.
        temp_face_path = tempfile.mktemp(suffix=".jpg")
        cv2.imwrite(temp_face_path, face)

        if model == "Siamese":
            siamese = SiameseNetwork()
            # Load weights on CPU so the app also runs on machines without a GPU.
            siamese.load_state_dict(torch.load(f'siamese_{person.lower()}.pth', map_location='cpu'))
            siamese.eval()

            with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as validation_temp_image:
                validation_temp_image.write(validation_image.read())
                validation_temp_image_path = validation_temp_image.name

            validation_image = cv2.imread(validation_temp_image_path)
            validation_face = get_face(validation_image)

            # OpenCV images are BGR; tell Streamlit so the preview colors are correct.
            st.image([face, validation_face], caption=["Face 1", "Face 2"], width=200, channels="BGR")

            validation_temp_face_path = tempfile.mktemp(suffix=".jpg")
            cv2.imwrite(validation_temp_face_path, validation_face)

            face = preprocess_image_siamese(temp_face_path)
            # Compare against the cropped validation face, not the full validation image.
            validation_face = preprocess_image_siamese(validation_temp_face_path)

            # Add a batch dimension before running the network.
            face = face.unsqueeze(0)
            validation_face = validation_face.unsqueeze(0)

            with torch.no_grad():
                output = siamese(face, validation_face)
                probability = output.item()
                pred = 1.0 if probability > threshold else 0.0

            if pred == 1:
                st.write("Match")
            else:
                st.write("Not Match")

        elif model == "HOG-SVM":
            with open(f'./svm_{person.lower()}.pkl', 'rb') as f:
                svm = joblib.load(f)
            with open(f'./pca_{person.lower()}.pkl', 'rb') as f:
                pca = joblib.load(f)

            face = cv2.imread(temp_face_path)
            face = preprocess_image_svm(face)
            st.image(face, caption="Face 1", width=200)

            # Renamed from `hog` to avoid shadowing the imported skimage function.
            hog_features = extract_hog_features(face)
            hog_pca = pca.transform([hog_features])
            pred = svm.predict(hog_pca)[0]

            if pred == 1:
                st.write("Match")
            else:
                st.write("Not Match")
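
# Streamlit UI: choose a model and person, optionally upload a validation image
# (Siamese only), capture a photo with the camera, and run verification.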
def main():
    st.title("Face Verification")

    # Per-person decision thresholds for the Siamese model's match probability.
    person_dict = {
        "Theo": 0.542,
        "Deverel": 0.5,
        "Justin": 0.5
    }

    model = st.selectbox("Select Model", ["Siamese", "HOG-SVM"])
    person = st.selectbox("Select Person", list(person_dict.keys()))

    uploaded_image = None
    if model == "Siamese":
        uploaded_image = st.file_uploader("Upload Validation Image (Siamese)", type=["jpg", "png"])

    enable = st.checkbox("Enable camera")
    captured_image = st.camera_input("Take a picture", disabled=not enable)

    # The Siamese model also needs the uploaded validation image before it can run.
    if captured_image and model == "Siamese" and uploaded_image is not None:
        verify(captured_image, model, person, uploaded_image, person_dict.get(person))
    elif captured_image and model == "HOG-SVM":
        verify(captured_image, model, person)


if __name__ == "__main__":
    main()