Spaces:
Running
Running
File size: 4,486 Bytes
c2b7cdc 08b8cf8 c2b7cdc 69ccc49 c2b7cdc 786ce5c c2b7cdc 119036b c2b7cdc 84e2069 c2b7cdc af18329 c2b7cdc 3391f44 0f130a4 3391f44 0f130a4 ba65288 3391f44 28e6e92 3391f44 28e6e92 c2b7cdc 2dc58f5 c2b7cdc af18329 c2b7cdc af18329 c2b7cdc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import tempfile

import cv2
import joblib
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from mtcnn import MTCNN
from PIL import Image
from skimage.feature import hog
from torchvision import transforms
from torchvision.models import resnet50
class VGGFaceEmbedding(nn.Module):
    """Image encoder that turns each image into a flat 2048-d feature vector.

    Despite its name, the backbone is torchvision's ResNet-50 with ImageNet
    weights, not a VGG-Face network. ResNet-50's own average-pool and fc
    head are dropped and replaced by adaptive average pooling plus a flatten.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet50(pretrained=True)
        # Keep every child except the last two (ResNet's avgpool and fc).
        # The attribute names below form the checkpoint's state_dict keys.
        feature_layers = list(backbone.children())[:-2]
        self.base_model = nn.Sequential(*feature_layers)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Encode an image batch into (batch, 2048) embeddings.

        Assumes ``x`` is a (batch, 3, H, W) float tensor — TODO confirm.
        """
        return self.flatten(self.pooling(self.base_model(x)))
class L1Dist(nn.Module):
    """Parameter-free layer giving the element-wise absolute difference of two embeddings."""

    def forward(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|`` (broadcast shape of the inputs)."""
        difference = input_embedding - validation_embedding
        return difference.abs()
class SiameseNetwork(nn.Module):
    """Two-branch verifier scoring whether two face images show the same person.

    Both images pass through one shared encoder. The element-wise L1
    distance between their embeddings goes through two stacked linear
    layers (no activation between them) and a sigmoid. The result is a
    single score in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names (and their registration order) define the
        # state_dict layout that saved checkpoints are loaded into.
        self.embedding = VGGFaceEmbedding()
        self.distance = L1Dist()
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image, validation_image):
        """Return a (batch, 1) tensor of match scores for two image batches."""
        first, second = [self.embedding(img) for img in (input_image, validation_image)]
        logits = self.fc2(self.fc1(self.distance(first, second)))
        return self.sigmoid(logits)
def preprocess_image_siamese(img):
    """Load an image and convert it into the Siamese network's input tensor.

    Args:
        img: A filesystem path, a binary file-like object, or an already
            opened ``PIL.Image.Image``. A PIL image is used as-is rather
            than being passed to ``Image.open``, which only accepts paths
            and file objects.

    Returns:
        A float tensor of shape (3, 224, 224) with values in [0, 1]. It has
        no batch dimension, so callers must add one (e.g. ``unsqueeze(0)``)
        before calling the network.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # NOTE(review): no ImageNet mean/std normalization is applied even though
    # the backbone is ImageNet-pretrained; presumably this matches how the
    # checkpoint was trained — confirm before adding it.
    if isinstance(img, Image.Image):
        rgb = img.convert("RGB")
    else:
        # The context manager releases the file handle PIL opens for paths.
        with Image.open(img) as opened:
            rgb = opened.convert("RGB")
    return transform(rgb)
def preprocess_image_svm(img):
    """Resize an OpenCV BGR image to 224x224 and convert it to grayscale.

    Args:
        img: BGR image array, as returned by ``cv2.imread``.

    Returns:
        A single-channel 224x224 array (same dtype as the input), ready for
        ``extract_hog_features``.
    """
    resized = cv2.resize(img, (224, 224))
    return cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
def extract_hog_features(img):
    """Compute the HOG descriptor of a grayscale image as a 1-D feature vector.

    Uses 9 orientation bins, 16x16-pixel cells and 4x4-cell blocks. The
    saved PCA/SVM models take this feature vector as input, so these
    settings presumably must match the ones used when they were fitted.
    """
    hog_params = dict(orientations=9, pixels_per_cell=(16, 16), cells_per_block=(4, 4))
    return hog(img, **hog_params)
def get_face(img):
    """Crop the first face MTCNN detects in an OpenCV BGR image.

    Args:
        img: BGR image array as returned by ``cv2.imread``, or ``None``
            (``cv2.imread`` returns ``None`` when it cannot read a file).

    Returns:
        The face region as a slice of ``img`` (still BGR). Returns ``None``
        when ``img`` is ``None``, no face is detected, or the detected box
        does not overlap the image.
    """
    if img is None:
        return None
    # NOTE(review): a new detector (model load) is created on every call;
    # consider caching it once per process.
    detector = MTCNN()
    # MTCNN expects RGB input, while OpenCV images are BGR.
    faces = detector.detect_faces(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    if not faces:
        return None
    x, y, w, h = faces[0]["box"]
    # MTCNN can report boxes starting slightly outside the frame (negative
    # x/y). Clamp the start to 0; abs() would move the whole box by twice
    # the overshoot and cut off part of the face.
    x1, y1 = max(x, 0), max(y, 0)
    x2, y2 = x + w, y + h
    face = img[y1:y2, x1:x2]
    return face if face.size else None
def verify(image, model, person, reference=None):
    """Run face verification on a captured photo and report the result in Streamlit.

    Args:
        image: Binary file-like object holding an encoded image (e.g. what
            ``st.camera_input`` returns). It is read once.
        model: ``"Siamese"`` or ``"HOG-SVM"``.
        person: Name whose per-person model files are loaded, lower-cased.
            The files are ``siamese_<name>.pth``, or ``svm_<name>.pkl`` and
            ``pca_<name>.pkl``, relative to the working directory.
        reference: Path or binary file-like object with a reference photo of
            ``person``. Only the Siamese model uses it, because its network
            compares two images. Without it, the Siamese branch reports an
            error instead of crashing.

    Writes "Match" / "Not Match", or an error/warning, to the page. Returns None.
    """
    # All temporary files live in one directory that is removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        capture_path = os.path.join(tmp_dir, "capture.jpg")
        with open(capture_path, "wb") as capture_file:
            capture_file.write(image.read())
        frame = cv2.imread(capture_path)
        if frame is None:
            st.error("Could not decode the captured image.")
            return
        face = get_face(frame)
        if face is None:
            st.warning("No face detected.")
            return

        # The crop is round-tripped through a JPEG file because the
        # preprocessing helpers load images from a path; this also keeps the
        # model inputs identical to the existing pipeline.
        face_path = os.path.join(tmp_dir, "face.jpg")
        cv2.imwrite(face_path, face)
        if model == "Siamese":
            is_match = _siamese_match(face_path, person, reference)
        elif model == "HOG-SVM":
            is_match = _svm_match(face_path, person)
        else:
            st.error(f"Unknown model: {model}")
            return

    if is_match is not None:
        st.write("Match" if is_match else "Not Match")


def _siamese_match(face_path, person, reference):
    """Compare the face at ``face_path`` with ``reference`` using the person's Siamese model.

    Returns True/False, or None after reporting an error when ``reference``
    is missing. The reference is assumed to be a face photo comparable to
    the training inputs — TODO confirm.
    """
    if reference is None:
        st.error(
            "The Siamese model compares two images; no reference image "
            f"of {person} was provided."
        )
        return None
    siamese = SiameseNetwork()
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state = torch.load(f"siamese_{person.lower()}.pth", map_location="cpu")
    siamese.load_state_dict(state)
    siamese.eval()
    # The network expects batched input; the preprocessor returns (3, 224, 224).
    face_batch = preprocess_image_siamese(face_path).unsqueeze(0)
    reference_batch = preprocess_image_siamese(reference).unsqueeze(0)
    with torch.no_grad():
        probability = siamese(face_batch, reference_batch).item()
    # Decision threshold on the sigmoid score.
    return probability > 0.7


def _svm_match(face_path, person):
    """Classify the face at ``face_path`` with the person's HOG -> PCA -> SVM pipeline."""
    # NOTE(review): joblib.load unpickles arbitrary objects; only load trusted model files.
    with open(f"./svm_{person.lower()}.pkl", "rb") as f:
        svm = joblib.load(f)
    with open(f"./pca_{person.lower()}.pkl", "rb") as f:
        pca = joblib.load(f)
    gray = preprocess_image_svm(cv2.imread(face_path))
    features = extract_hog_features(gray)
    prediction = svm.predict(pca.transform([features]))
    # predict returns one label per sample; label 1 means "this person".
    return prediction[0] == 1
def main():
    """Build the Streamlit page: pick a model and a person, then verify a camera shot."""
    st.title("Face Verification")
    model_choice = st.selectbox("Select Model", ["Siamese", "HOG-SVM"])
    person_choice = st.selectbox("Select Person", ["Theo"])
    camera_enabled = st.checkbox("Enable camera")
    # The camera widget is always rendered but stays disabled until the box is ticked.
    snapshot = st.camera_input("Take a picture", disabled=not camera_enabled)
    if not snapshot:
        return
    verify(snapshot, model_choice, person_choice)


if __name__ == "__main__":
    main()
|