hasnanmr committed
Commit d621e56 · Parent(s): 07ecf45

add model recognition

Files changed (2):
  1. app.py +99 -49
  2. requirements.txt +0 -2
app.py CHANGED
@@ -1,50 +1,71 @@
-# import streamlit as st
-# import torch
-# from facenet_pytorch import MTCNN
-# import pickle
-# import cv2
-# from PIL import Image
-# import numpy as np
-# from transformers import ViTImageProcessor, ViTModel
-# import torch.nn as nn
-# from torchvision import transforms
-# from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
-# import av
-
-# class ViT(nn.Module):
-#     def __init__(self, base_model):
-#         super(ViT, self).__init__()
-#         self.base_model = base_model
-
-#     def forward(self, x):
-#         x = self.base_model(x).pooler_output
-#         return x
-
-# @st.cache_resource
-# def load_model():
-#     model_name = "google/vit-base-patch16-224"
-#     processor = ViTImageProcessor.from_pretrained(model_name)
-#     base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
-#     model = ViT(base_model)
-#     model.load_state_dict(torch.load('faceViT6.pth', map_location=torch.device('cpu')))
-#     model.eval()
-#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-#     model.to(device)
-#     return model, processor, device
-
+import streamlit as st
 import torch
 from facenet_pytorch import MTCNN
+import pickle
 import cv2
-import numpy as np
 import gradio as gr
 from PIL import Image
+import numpy as np
+from transformers import ViTImageProcessor, ViTModel
+import torch.nn as nn
+from torchvision import transforms
+
+
+# Define the ViT class
+class ViT(nn.Module):
+    def __init__(self, base_model):
+        super(ViT, self).__init__()
+        self.base_model = base_model
+
+    def forward(self, x):
+        x = self.base_model(x).pooler_output
+        return x
+
+# Load the model and processor
+model_name = "google/vit-base-patch16-224"
+processor = ViTImageProcessor.from_pretrained(model_name)
+base_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
+model = ViT(base_model)
+model.load_state_dict(torch.load('faceViT6.pth'))
+
+# Set the model to evaluation mode
+model.eval()
 
-# Set the device
+# Check if CUDA is available and move the model to GPU if it is
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model.to(device)
 
-# Initialize MTCNN
+# Initialize MTCNN for face detection
 mtcnn = MTCNN(keep_all=True, min_face_size=20, thresholds=[0.6, 0.7, 0.7], device=device)
 
+# Define the transformation
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor()
+])
+
+# Load the database of embeddings
+with open('face_database_ViT6.pkl', 'rb') as f:
+    database = pickle.load(f)
+
+def cosine_similarity(embedding1, embedding2):
+    similarity = torch.nn.functional.cosine_similarity(embedding1.unsqueeze(0), embedding2.unsqueeze(0))
+    return similarity.item()
+
+def compare_embeddings(embedding, database, threshold=0.9):
+    best_match = None
+    best_similarity = threshold
+    for name, db_embeddings in database.items():
+        for db_embedding in db_embeddings:
+            db_embedding = torch.tensor(db_embedding).to(device)
+            similarity = cosine_similarity(embedding, db_embedding)
+            if similarity > best_similarity:
+                best_match = name
+                best_similarity = similarity
+    if best_match is not None:
+        return best_match, best_similarity
+    return None, None
+
 def align_faces(frame, mtcnn, device):
     boxes, _ = mtcnn.detect(frame)
     aligned_faces = []
@@ -69,17 +90,46 @@ def draw_annotations(frame, detections, names=None):
 def process_image(image):
     frame = np.array(image)
     aligned_faces, boxes = align_faces(frame, mtcnn, device)
-    annotated_image = draw_annotations(frame, boxes)
-    return annotated_image
+
+    if aligned_faces is not None:
+        names = []
+        for face in aligned_faces:
+            face = transform(face)
+            face = face.unsqueeze(0).to(device)
+            with torch.no_grad():
+                embedding = model(face)
+            name, similarity = compare_embeddings(embedding, database)
+            if name is not None:
+                names.append(f"{name} ({similarity:.2f})")
+            else:
+                names.append("Unknown")
+        annotated_image = draw_annotations(frame, boxes, names)
+        result = "Face recognition complete."
+    else:
+        annotated_image = frame
+        result = "No faces detected."
+
+    return annotated_image, result
+
+def capture_and_process_image(webcam_image):
+    captured_img, result = process_image(webcam_image)
+    return captured_img, result
 
 # Create the Gradio interface
-iface = gr.Interface(
-    fn=process_image,
-    inputs=gr.Image(type="pil"),
-    outputs=gr.Image(type="numpy"),
-    title="Face Detection with MTCNN",
-    description="Upload an image and the model will detect and align faces in it."
-)
-
-# Launch the interface
-iface.launch(share=True, debug=True)
+with gr.Blocks() as demo:
+    with gr.Row():
+        # Webcam input component
+        webcam_input = gr.Image(source="webcam", streaming=True, label="Webcam Input", height=483)
+        # Captured image display
+        captured_image = gr.Image(label="Captured Image", height=483)
+    # Capture button
+    capture_button = gr.Button("Capture Image")
+    # Result output textbox
+    result_output = gr.Textbox(label="Inference Result")
+
+    # Define the button click action
+    capture_button.click(fn=capture_and_process_image, inputs=webcam_input, outputs=[captured_image, result_output])
+
+if __name__ == "__main__":
+    # Launch the interface with share=True to create a public link
+    demo.launch(share=True, debug=True)
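
Editor's note: the heart of this commit is the matching step. Each MTCNN-aligned face is embedded by the ViT wrapper and compared against every entry of the pickled face_database_ViT6.pkl by cosine similarity, and the best name whose score clears the 0.9 threshold wins. Below is a minimal self-contained sketch of that lookup, with a toy two-name database of random vectors standing in for real reference embeddings (the 384 dimension matches ViT-Small's hidden size; the names and values are illustrative, not repo data):

import torch
import torch.nn.functional as F

# Toy stand-in for face_database_ViT6.pkl: name -> list of reference embeddings.
database = {
    "alice": [torch.randn(384)],
    "bob": [torch.randn(384), torch.randn(384)],
}

def best_match(embedding, database, threshold=0.9):
    # Return (name, similarity) for the closest entry above the threshold,
    # or (None, None) when no stored face is similar enough.
    best_name, best_sim = None, threshold
    for name, refs in database.items():
        for ref in refs:
            sim = F.cosine_similarity(embedding.unsqueeze(0), ref.unsqueeze(0)).item()
            if sim > best_sim:
                best_name, best_sim = name, sim
    return (best_name, best_sim) if best_name is not None else (None, None)

query = torch.randn(384)            # stands in for model(face) output
print(best_match(query, database))  # random vectors rarely clear 0.9

Initializing the running best similarity at the threshold, as compare_embeddings does, folds the cutoff into the arg-max loop, so no separate acceptance check is needed afterwards.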
requirements.txt CHANGED
@@ -2,8 +2,6 @@ torch==2.2.1
 torchaudio==2.2.1
 torchsummary==1.5.1
 torchvision==0.17.1
-streamlit==1.31.1
-streamlit-webrtc==0.47.6
 sympy==1.12
 tenacity==8.2.3
 tensorboard==2.15.2
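
Two caveats are worth flagging when the two files are read together. First, requirements.txt drops streamlit, yet the new app.py still opens with import streamlit as st, so the app will raise ModuleNotFoundError once the package is absent from the image; the import looks like a leftover from the deleted draft. Second, that draft loaded the checkpoint with map_location=torch.device('cpu'), while the new code calls torch.load('faceViT6.pth') bare and only moves the model afterwards, which can fail on a CPU-only Space if the checkpoint was saved from a GPU. A hedged sketch of the more defensive load order (the nn.Linear stand-in is illustrative, not the repo's model):

import torch
import torch.nn as nn

# Illustrative stand-in for the repo's ViT wrapper; any nn.Module
# whose state_dict matches the checkpoint works the same way.
model = nn.Linear(384, 384)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location retargets saved tensors at load time, so a checkpoint
# written on a GPU machine still deserializes on a CPU-only Space.
state_dict = torch.load('faceViT6.pth', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()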