varma123 committed
Commit 64a53bd · verified · 1 Parent(s): 97b144d

Update app.py

Files changed (1)
  1. app.py +33 -29
app.py CHANGED
@@ -2,19 +2,18 @@ import gradio as gr
 import torch
 import torch.nn.functional as F
 from facenet_pytorch import MTCNN, InceptionResnetV1
-import numpy as np
 import cv2
 from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 from pytorch_grad_cam.utils.image import show_cam_on_image
-from torchvision import transforms
 from PIL import Image
+import numpy as np
 import warnings
 
 warnings.filterwarnings("ignore")
 
 # Download and Load Model
-DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 mtcnn = MTCNN(
     select_largest=False,
@@ -25,49 +24,58 @@ mtcnn = MTCNN(
 model = InceptionResnetV1(
     pretrained="vggface2",
     classify=True,
-    num_classes=1,
+    num_classes=2,  # Change to 2 classes (real or fake)
     device=DEVICE
-)
+).eval()
 
 checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
 model.load_state_dict(checkpoint['model_state_dict'])
 model.to(DEVICE)
-model.eval()
 
 # Model Inference
 def predict_frame(frame):
-    """Predict whether the input frame contains real or fake faces"""
+    """Predict whether the input frame contains a real or fake face"""
     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
     frame_pil = Image.fromarray(frame)
 
     face = mtcnn(frame_pil)
     if face is None:
-        raise Exception('No face detected')
-    face = face.unsqueeze(0)  # add the batch dimension
-    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
+        return None, None  # No face detected
 
+    # Preprocess the face
+    face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
     face = face.to(DEVICE, dtype=torch.float32) / 255.0
 
+    # Predict
+    with torch.no_grad():
+        output = torch.sigmoid(model(face).squeeze(0))
+        prediction = "real" if output.item() < 0.5 else "fake"
+
+    # Confidence scores
+    real_prediction = 1 - output.item()
+    fake_prediction = output.item()
+
+    confidences = {
+        'real': real_prediction,
+        'fake': fake_prediction
+    }
+
+    # Visualize
     target_layers = [model.block8.branch1[-1]]
     use_cuda = True if torch.cuda.is_available() else False
     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
     targets = [ClassifierOutputTarget(0)]
-
     grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
     grayscale_cam = grayscale_cam[0, :]
-    visualization = show_cam_on_image(frame, grayscale_cam, use_rgb=True)
-    face_with_mask = cv2.addWeighted(frame, 1, visualization, 0.5, 0)
-
-    with torch.no_grad():
-        output = torch.sigmoid(model(face).squeeze(0))
-        prediction = "real" if output.item() < 0.5 else "fake"
+    face_np = face.squeeze(0).permute(1, 2, 0).cpu().numpy()
+    visualization = show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
+    face_with_mask = cv2.addWeighted((face_np * 255).astype(np.uint8), 1, (visualization * 255).astype(np.uint8), 0.5, 0)
 
     return prediction, face_with_mask
 
-
-# Function to process video
 def predict_video(input_video):
     cap = cv2.VideoCapture(input_video)
+
     frames = []
     confidences = []
 
@@ -82,18 +90,12 @@ def predict_video(input_video):
         confidences.append(prediction)
 
     cap.release()
-    list=[]
-    list.append(set(confidences))
-    if( 'fake' in list):
-        final_prediction='fake'
-    else:
-        final_prediction='real'
+
     # Determine the final prediction based on the maximum occurrence of predictions
-    # final_prediction = max(set(confidences), key=confidences.count)
+    final_prediction = 'fake' if confidences.count('fake') > confidences.count('real') else 'real'
 
     return final_prediction, frames
 
-
 # Gradio Interface
 interface = gr.Interface(
     fn=predict_video,
@@ -102,8 +104,10 @@ interface = gr.Interface(
     ],
     outputs=[
         gr.Label(label="Class"),
-        gr.Video(label="Face with Explainability")
+        gr.Image(label="Face with Explainability", type="numpy")
     ],
+    title="Video Face Authentication",
+    description="Detect whether the faces in the video are real or fake."
 )
 
 interface.launch()
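The loop that feeds each decoded frame to predict_frame sits between the two predict_video hunks and is not displayed in this diff. A minimal sketch of what that loop plausibly looks like, using only names visible in the surrounding context (cap, frames, confidences, predict_frame); the frame_skip sampling stride is an illustrative assumption, not something this commit shows:

    # Hypothetical sketch of the undisplayed loop body inside predict_video();
    # frame_skip is assumed, the actual app may sample frames differently.
    frame_skip = 10
    frame_idx = 0
    while cap.isOpened():
        ret, frame = cap.read()  # cv2.VideoCapture yields (success flag, BGR frame)
        if not ret:
            break
        if frame_idx % frame_skip == 0:
            prediction, face_with_mask = predict_frame(frame)
            if prediction is not None:  # updated predict_frame returns (None, None) when no face is found
                frames.append(face_with_mask)
                confidences.append(prediction)
        frame_idx += 1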