Spaces:
Running
Running
File size: 4,272 Bytes
c2b7cdc 69ccc49 c2b7cdc 119036b c2b7cdc af18329 c2b7cdc af18329 c2b7cdc af18329 119036b cd4a984 c2b7cdc cd4a984 c2b7cdc af18329 c2b7cdc af18329 c2b7cdc af18329 c2b7cdc af18329 c2b7cdc af18329 c2b7cdc af18329 c2b7cdc 4462db7 c2b7cdc af18329 c2b7cdc af18329 c2b7cdc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import cv2
import streamlit as st
import tempfile
import torch
import torch.nn as nn
from torchvision import transforms
from mtcnn import MTCNN
from skimage.feature import hog
import joblib
import numpy as np
class VGGFaceEmbedding(nn.Module):
    """Image-embedding backbone: ResNet-50 with its classifier head removed.

    Despite the class name, the backbone is torchvision's ResNet-50 (not VGG).
    ``forward`` maps a batch of images to 2048-dimensional vectors by
    global-average-pooling the last convolutional feature map.
    """

    def __init__(self):
        super().__init__()
        # Imported here because the module-level imports never bring
        # `resnet50` into scope (constructing this class raised NameError).
        from torchvision.models import resnet50

        # NOTE: `pretrained=True` downloads ImageNet weights; newer torchvision
        # releases deprecate this flag in favour of `weights=`.
        backbone = resnet50(pretrained=True)
        # Keep the convolutional trunk; drop ResNet's own avgpool and fc layers.
        self.base_model = nn.Sequential(*list(backbone.children())[:-2])
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Embed a batch of images ``x`` (N, 3, H, W) as (N, 2048) vectors."""
        x = self.base_model(x)
        x = self.pooling(x)
        return self.flatten(x)
class L1Dist(nn.Module):
    """Parameter-free layer giving the element-wise absolute difference of two embeddings."""

    def __init__(self):
        super().__init__()

    def forward(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|`` element-wise."""
        difference = input_embedding - validation_embedding
        return difference.abs()
class SiameseNetwork(nn.Module):
    """Score how likely two face images are to show the same person.

    Both images pass through one shared embedding backbone. The element-wise
    absolute difference of the two embeddings then goes through two linear
    layers and a sigmoid, producing a value in (0, 1) for each pair.

    The attribute names below are the keys of the saved state_dict that gets
    loaded into this network, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.embedding = VGGFaceEmbedding()  # shared by both inputs
        self.distance = L1Dist()
        self.fc1 = nn.Linear(2048, 512)
        # NOTE(review): there is no non-linearity between fc1 and fc2, so they
        # compose to a single affine map; kept as-is to match trained weights.
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image, validation_image):
        """Return one match score per image pair (shape (N, 1) for batched input)."""
        first = self.embedding(input_image)
        second = self.embedding(validation_image)
        hidden = self.fc1(self.distance(first, second))
        return self.sigmoid(self.fc2(hidden))
def preprocess_image_siamese(img):
    """Convert a BGR face crop into the tensor the Siamese network consumes.

    Args:
        img: H x W x 3 image array in OpenCV's BGR channel order
            (presumably uint8, as produced by cv2 decoding).

    Returns:
        A float tensor of shape (3, 224, 224), RGB order, values in [0, 1],
        without a batch dimension.
    """
    # torchvision's Resize accepts only PIL images or tensors; the original
    # passed a raw NumPy array and raised TypeError. Convert to PIL first.
    # NOTE(review): no mean/std normalisation is applied; confirm this matches
    # the transform the checkpoint was trained with.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return transform(rgb)
def preprocess_image_svm(img):
    """Resize a BGR face crop to 224x224 and convert it to a grayscale image for HOG."""
    resized = cv2.resize(img, (224, 224))
    return cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
def extract_hog_features(img):
    """Return the HOG feature vector of a grayscale image.

    Uses 9 orientation bins, 16x16-pixel cells and 4x4-cell blocks. The
    descriptor length depends on these settings, and the result is passed
    to a saved PCA model, so they must not drift from training.
    """
    return hog(
        img,
        orientations=9,
        pixels_per_cell=(16, 16),
        cells_per_block=(4, 4),
    )
def get_face(img, detector=None):
    """Crop the first face MTCNN reports in a BGR image.

    Args:
        img: H x W x 3 image in OpenCV's BGR channel order.
        detector: optional object with a ``detect_faces(rgb_image)`` method
            returning dicts with a ``'box'`` of ``[x, y, width, height]``.
            A new ``MTCNN()`` is created when omitted; pass one in to
            reuse it across calls.

    Returns:
        The BGR sub-array for the first detection's box, clipped to the
        image bounds, or None if nothing was detected or the clipped box
        is empty.
    """
    if detector is None:
        detector = MTCNN()
    # MTCNN works on RGB images; reverse OpenCV's BGR channel order.
    rgb = np.ascontiguousarray(img[..., ::-1])
    faces = detector.detect_faces(rgb)
    if not faces:
        return None

    x, y, w, h = faces[0]['box']
    # Boxes may start outside the image (negative x/y). Clip both corners;
    # abs() would move the box instead of trimming it.
    height, width = img.shape[:2]
    x1, y1 = max(0, x), max(0, y)
    x2, y2 = min(width, x + w), min(height, y + h)
    if x2 <= x1 or y2 <= y1:
        return None
    return img[y1:y2, x1:x2]
def verify(image, model, person, reference_image=None):
    """Check whether the face in a captured photo belongs to ``person``.

    Decodes the photo, crops the first detected face with ``get_face`` and
    classifies it with the selected model. The outcome is written to the
    Streamlit page as "Match" / "Not Match", or as a diagnostic message.

    Args:
        image: file-like object whose ``read()`` returns encoded image bytes
            (e.g. the value of ``st.camera_input``).
        model: "Siamese" or "HOG-SVM".
        person: name whose lower-cased form selects the model files
            (``siamese_<person>.pth``, ``svm_<person>.pkl``, ``pca_<person>.pkl``).
        reference_image: optional file-like reference photo of ``person``.
            The Siamese model compares two faces and needs it; when omitted,
            ``./reference_<person>.jpg`` is read instead.
    """
    frame = _decode_image(image)
    if frame is None:
        st.write("Could not read the captured image")
        return
    face = get_face(frame)
    if face is None:
        st.write("Face not detected")
        return

    if model == "Siamese":
        is_match = _siamese_match(face, person, reference_image)
    elif model == "HOG-SVM":
        is_match = _hog_svm_match(face, person)
    else:
        st.write(f"Unknown model: {model}")
        return

    # None means the helper already explained why no decision was made.
    if is_match is not None:
        st.write("Match" if is_match else "Not Match")


def _decode_image(file_obj):
    """Decode encoded image bytes from a file-like object into a BGR array, or None."""
    data = file_obj.read()
    if not data:
        return None
    # Decode in memory: the previous NamedTemporaryFile(delete=False) was
    # never removed, leaking one file per captured photo.
    return cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)


def _siamese_match(face, person, reference_image=None):
    """Compare ``face`` with a reference face of ``person`` using the Siamese model.

    Returns True/False, or None after writing a message when no usable
    reference face is available.
    """
    if reference_image is not None:
        reference = _decode_image(reference_image)
    else:
        # NOTE(review): this file name is a convention introduced here to
        # mirror the per-person model files; confirm it matches deployed assets.
        reference = cv2.imread(f'./reference_{person.lower()}.jpg')
    if reference is None:
        st.write(f"Siamese model needs a reference image of {person}")
        return None
    reference_face = get_face(reference)
    if reference_face is None:
        st.write("Face not detected in the reference image")
        return None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    siamese = SiameseNetwork()
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state = torch.load(f'siamese_{person.lower()}.pth', map_location=device)
    siamese.load_state_dict(state)
    siamese.to(device)
    siamese.eval()  # inference behaviour for the backbone's BatchNorm layers

    # The network takes batched (N, C, H, W) input; add the batch axis.
    input_tensor = preprocess_image_siamese(face).unsqueeze(0).to(device)
    reference_tensor = preprocess_image_siamese(reference_face).unsqueeze(0).to(device)
    with torch.no_grad():
        probability = siamese(input_tensor, reference_tensor).item()
    return probability > 0.5


def _hog_svm_match(face, person):
    """Classify ``face`` with the person's HOG -> PCA -> SVM pipeline; return True on label 1."""
    # joblib.load unpickles arbitrary objects: only load model files you trust.
    with open(f'./svm_{person.lower()}.pkl', 'rb') as f:
        svm = joblib.load(f)
    with open(f'./pca_{person.lower()}.pkl', 'rb') as f:
        pca = joblib.load(f)
    gray = preprocess_image_svm(face)
    # Not named `hog`, so skimage's hog function is not shadowed.
    hog_features = extract_hog_features(gray)
    hog_pca = pca.transform([hog_features])
    pred = svm.predict(hog_pca)
    # One input row -> one predicted label.
    return bool(pred[0] == 1)
def main():
    """Build the Streamlit page and verify a captured photo once one is taken."""
    st.title("Real-time Face Verification App")
    model_name = st.selectbox("Select Model", ["Siamese", "HOG-SVM"])
    person_name = st.selectbox("Select Person", ["Theo"])
    camera_enabled = st.checkbox("Enable camera")
    # The camera widget is always rendered but stays disabled until the box is ticked.
    snapshot = st.camera_input("Take a picture", disabled=not camera_enabled)
    if not snapshot:
        return
    verify(snapshot, model_name, person_name)
# Launch the app only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()
|