add model recognition
- app.py  +99 -49
- requirements.txt  +0 -2
app.py CHANGED
@@ -1,50 +1,71 @@
-
-# import torch
-# from facenet_pytorch import MTCNN
-# import pickle
-# import cv2
-# from PIL import Image
-# import numpy as np
-# from transformers import ViTImageProcessor, ViTModel
-# import torch.nn as nn
-# from torchvision import transforms
-# from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
-# import av
-
-# class ViT(nn.Module):
-#     def __init__(self, base_model):
-#         super(ViT, self).__init__()
-#         self.base_model = base_model
-
-#     def forward(self, x):
-#         x = self.base_model(x).pooler_output
-#         return x
-
-# @st.cache_resource
-# def load_model():
-#     model_name = "google/vit-base-patch16-224"
-#     processor = ViTImageProcessor.from_pretrained(model_name)
-#     base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
-#     model = ViT(base_model)
-#     model.load_state_dict(torch.load('faceViT6.pth', map_location=torch.device('cpu')))
-#     model.eval()
-#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-#     model.to(device)
-#     return model, processor, device
-
+import streamlit as st
 import torch
 from facenet_pytorch import MTCNN
+import pickle
 import cv2
-import numpy as np
 import gradio as gr
 from PIL import Image
+import numpy as np
+from transformers import ViTImageProcessor, ViTModel
+import torch.nn as nn
+from torchvision import transforms
+
+
+# Define the ViT class
+class ViT(nn.Module):
+    def __init__(self, base_model):
+        super(ViT, self).__init__()
+        self.base_model = base_model
+
+    def forward(self, x):
+        x = self.base_model(x).pooler_output
+        return x
+
+# Load the model and processor
+model_name = "google/vit-base-patch16-224"
+processor = ViTImageProcessor.from_pretrained(model_name)
+base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
+model = ViT(base_model)
+model.load_state_dict(torch.load('faceViT6.pth'))
+
+# Set the model to evaluation mode
+model.eval()
 
-#
+# Check if CUDA is available and move the model to GPU if it is
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model.to(device)
 
-# Initialize MTCNN
+# Initialize MTCNN for face detection
 mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)
 
+# Define the transformation
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor()
+])
+
+# Load the database of embeddings
+with open('face_database_ViT6.pkl', 'rb') as f:
+    database = pickle.load(f)
+
+def cosine_similarity(embedding1, embedding2):
+    similarity = torch.nn.functional.cosine_similarity(embedding1.unsqueeze(0), embedding2.unsqueeze(0))
+    return similarity.item()
+
+def compare_embeddings(embedding, database, threshold=0.9):
+    best_match = None
+    best_similarity = threshold
+    for name, db_embeddings in database.items():
+        for db_embedding in db_embeddings:
+            db_embedding = torch.tensor(db_embedding).to(device)
+            similarity = cosine_similarity(embedding, db_embedding)
+            if similarity > best_similarity:
+                best_match = name
+                best_similarity = similarity
+    if best_match is not None:
+        return best_match, best_similarity
+    return None, None
+
 def align_faces(frame, mtcnn, device):
     boxes, _ = mtcnn.detect(frame)
     aligned_faces = []
@@ -69,17 +90,46 @@ def draw_annotations(frame, detections, names=None):
 def process_image(image):
     frame = np.array(image)
     aligned_faces, boxes = align_faces(frame, mtcnn, device)
-
-
+
+    if aligned_faces is not None:
+        names = []
+        for face in aligned_faces:
+            face = transform(face)
+            face = face.unsqueeze(0).to(device)
+            with torch.no_grad():
+                embedding = model(face)
+            name, similarity = compare_embeddings(embedding, database)
+            if name is not None:
+                names.append(f"{name} ({similarity:.2f})")
+            else:
+                names.append("Unknown")
+        annotated_image = draw_annotations(frame, boxes, names)
+        result = "Face recognition complete."
+    else:
+        annotated_image = frame
+        result = "No faces detected."
+
+    return annotated_image, result
+
+def capture_and_process_image(webcam_image):
+    captured_img, result = process_image(webcam_image)
+    return captured_img, result
 
 # Create the Gradio interface
-
-
-
-
-
-
-
-
-#
-
+with gr.Blocks() as demo:
+    with gr.Row():
+        # Webcam input component
+        webcam_input = gr.Image(source="webcam", streaming=True, label="Webcam Input", height=483)
+        # Captured image display
+        captured_image = gr.Image(label="Captured Image", height=483)
+    # Capture button
+    capture_button = gr.Button("Capture Image")
+    # Result output textbox
+    result_output = gr.Textbox(label="Inference Result")
+
+    # Define the button click action
+    capture_button.click(fn=capture_and_process_image, inputs=webcam_input, outputs=[captured_image, result_output])
+
+if __name__ == "__main__":
+    # Launch the interface with share=True to create a public link
+    demo.launch(share=True, debug=True)
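Note on the embedding database: the new app.py loads face_database_ViT6.pkl, but this commit does not show how that file is built. The sketch below is one way to produce a compatible pickle, assuming only the structure the new code iterates over (a dict mapping each person's name to a list of embedding vectors) and a hypothetical enrollment_photos/<name>/*.jpg folder layout; the ViT wrapper, checkpoint, and 224x224 preprocessing mirror app.py.

# Hypothetical enrollment script (not part of this commit): builds the pickle
# that app.py expects, i.e. {person_name: [embedding, embedding, ...]}.
import os
import pickle

import torch
import torch.nn as nn
from PIL import Image
from facenet_pytorch import MTCNN
from torchvision import transforms
from transformers import ViTModel


class ViT(nn.Module):  # same wrapper as in app.py
    def __init__(self, base_model):
        super(ViT, self).__init__()
        self.base_model = base_model

    def forward(self, x):
        return self.base_model(x).pooler_output


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViT(ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224"))
model.load_state_dict(torch.load('faceViT6.pth', map_location=device))
model.eval().to(device)

mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

database = {}
enroll_dir = "enrollment_photos"  # assumed layout: enrollment_photos/<name>/*.jpg
for name in os.listdir(enroll_dir):
    embeddings = []
    for fname in os.listdir(os.path.join(enroll_dir, name)):
        img = Image.open(os.path.join(enroll_dir, name, fname)).convert("RGB")
        boxes, _ = mtcnn.detect(img)  # detect the face, as align_faces does
        if boxes is None:
            continue
        x1, y1, x2, y2 = [int(v) for v in boxes[0]]
        face = transform(img.crop((x1, y1, x2, y2)))  # crop + 224x224 resize, as in app.py
        with torch.no_grad():
            emb = model(face.unsqueeze(0).to(device)).squeeze(0).cpu().numpy()
        embeddings.append(emb)  # 1-D vector; app.py rebuilds it with torch.tensor
    if embeddings:
        database[name] = embeddings

with open('face_database_ViT6.pkl', 'wb') as f:
    pickle.dump(database, f)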
requirements.txt CHANGED
@@ -2,8 +2,6 @@ torch==2.2.1
 torchaudio==2.2.1
 torchsummary==1.5.1
 torchvision==0.17.1
-streamlit==1.31.1
-streamlit-webrtc==0.47.6
 sympy==1.12
 tenacity==8.2.3
 tensorboard==2.15.2
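A quick way to sanity-check the new pipeline without launching the Gradio UI is to call process_image directly on a still photo. This is a sketch under stated assumptions: test.jpg is a hypothetical local file, and faceViT6.pth plus face_database_ViT6.pkl must be present exactly as app.py expects, since they are loaded at import time.

# Hypothetical smoke test for the recognition pipeline in app.py.
import cv2
from PIL import Image

from app import process_image  # importing app.py loads the model and the embedding database

annotated, result = process_image(Image.open("test.jpg").convert("RGB"))
print(result)  # "Face recognition complete." or "No faces detected."
cv2.imwrite("annotated.jpg", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))  # frame is RGB (from PIL)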