varma123 commited on
Commit
4340122
·
verified ·
1 Parent(s): 117711f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -4
app.py CHANGED
@@ -1,7 +1,85 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import Libraries
2
  import gradio as gr
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from facenet_pytorch import MTCNN, InceptionResnetV1
6
+ import numpy as np
7
+ from PIL import Image
8
+ import cv2
9
+ from pytorch_grad_cam import GradCAM
10
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
11
+ from pytorch_grad_cam.utils.image import show_cam_on_image
12
+ import warnings
13
+ warnings.filterwarnings("ignore")
14
 
15
+ # Download and Load Model
16
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
17
 
18
+ mtcnn = MTCNN(
19
+ select_largest=False,
20
+ post_process=False,
21
+ device=DEVICE
22
+ ).to(DEVICE).eval()
23
+ model = InceptionResnetV1(
24
+ pretrained="vggface2",
25
+ classify=True,
26
+ num_classes=1,
27
+ device=DEVICE
28
+ )
29
+
30
+ checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
31
+ model.load_state_dict(checkpoint['model_state_dict'])
32
+ model.to(DEVICE)
33
+ model.eval()
34
+ # Model Inference
35
+ def predict(input_image:Image.Image):
36
+ """Predict the label of the input_image"""
37
+ face = mtcnn(input_image)
38
+ if face is None:
39
+ raise Exception('No face detected')
40
+ face = face.unsqueeze(0) # add the batch dimension
41
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
42
+
43
+ # convert the face into a numpy array to be able to plot it
44
+ prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
45
+ prev_face = prev_face.astype('uint8')
46
+
47
+ face = face.to(DEVICE)
48
+ face = face.to(torch.float32)
49
+ face = face / 255.0
50
+ face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
51
+
52
+ target_layers=[model.block8.branch1[-1]]
53
+ use_cuda = True if torch.cuda.is_available() else False
54
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
55
+ targets = [ClassifierOutputTarget(0)]
56
+
57
+ grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
58
+ grayscale_cam = grayscale_cam[0, :]
59
+ visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
60
+ face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)
61
+
62
+ with torch.no_grad():
63
+ output = torch.sigmoid(model(face).squeeze(0))
64
+ prediction = "real" if output.item() < 0.5 else "fake"
65
+
66
+ real_prediction = 1 - output.item()
67
+ fake_prediction = output.item()
68
+
69
+ confidences = {
70
+ 'real': real_prediction,
71
+ 'fake': fake_prediction
72
+ }
73
+ return confidences, face_with_mask
74
+
75
+ # Gradio Interface
76
+ interface = gr.Interface(
77
+ fn=predict,
78
+ inputs=[
79
+ gr.inputs.Image(label="Input Image", type="pil")
80
+ ],
81
+ outputs=[
82
+ gr.outputs.Label(label="Class"),
83
+ gr.outputs.Image(label="Face with Explainability", type="pil")
84
+ ],
85
+ ).launch()