dvieri commited on
Commit
119036b
·
verified ·
1 Parent(s): cd4a984

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -1
app.py CHANGED
@@ -8,6 +8,45 @@ from skimage.feature import hog
8
  import joblib
9
  import numpy as np
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def preprocess_image_siamese(img):
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
@@ -46,7 +85,28 @@ def verify(image, model, person):
46
  face = get_face(image)
47
 
48
  if face is not None:
49
- if model == "HOG-SVM":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with open(f'./svm_{person.lower()}.pkl', 'rb') as f:
51
  svm = joblib.load(f)
52
  with open(f'./pca_{person.lower()}.pkl', 'rb') as f:
 
8
  import joblib
9
  import numpy as np
10
 
11
+ class VGGFaceEmbedding(nn.Module):
12
+ def __init__(self):
13
+ super(VGGFaceEmbedding, self).__init__()
14
+ self.base_model = resnet50(pretrained=True)
15
+ self.base_model = nn.Sequential(*list(self.base_model.children())[:-2])
16
+ self.pooling = nn.AdaptiveAvgPool2d((1, 1))
17
+ self.flatten = nn.Flatten()
18
+
19
+ def forward(self, x):
20
+ x = self.base_model(x)
21
+ x = self.pooling(x)
22
+ x = self.flatten(x)
23
+ return x
24
+
25
+ class L1Dist(nn.Module):
26
+ def __init__(self):
27
+ super(L1Dist, self).__init__()
28
+
29
+ def forward(self, input_embedding, validation_embedding):
30
+ return torch.abs(input_embedding - validation_embedding)
31
+
32
+ class SiameseNetwork(nn.Module):
33
+ def __init__(self):
34
+ super(SiameseNetwork, self).__init__()
35
+ self.embedding = VGGFaceEmbedding()
36
+ self.distance = L1Dist()
37
+ self.fc1 = nn.Linear(2048, 512)
38
+ self.fc2 = nn.Linear(512, 1)
39
+ self.sigmoid = nn.Sigmoid()
40
+
41
+ def forward(self, input_image, validation_image):
42
+ input_embedding = self.embedding(input_image)
43
+ validation_embedding = self.embedding(validation_image)
44
+ distances = self.distance(input_embedding, validation_embedding)
45
+ x = self.fc1(distances)
46
+ x = self.fc2(x)
47
+ x = self.sigmoid(x)
48
+ return x
49
+
50
  def preprocess_image_siamese(img):
51
  transform = transforms.Compose([
52
  transforms.Resize((224, 224)),
 
85
  face = get_face(image)
86
 
87
  if face is not None:
88
+ if model == "Siamese":
89
+ siamese = SiameseNetwork()
90
+ siamese.load_state_dict(torch.load(f'siamese_{person.lower()}.pth'))
91
+
92
+ face = preprocess_image_siamese(face)
93
+
94
+ # Move to device
95
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
96
+ model.to(device)
97
+ face = face.to(device)
98
+
99
+ with torch.no_grad():
100
+ output = model(face)
101
+ probability = output.item()
102
+ pred = 1.0 if probability > 0.5 else 0.0
103
+
104
+ if pred == 1:
105
+ st.write("Match")
106
+ else:
107
+ st.write("Not Match")
108
+
109
+ elif model == "HOG-SVM":
110
  with open(f'./svm_{person.lower()}.pkl', 'rb') as f:
111
  svm = joblib.load(f)
112
  with open(f'./pca_{person.lower()}.pkl', 'rb') as f: