File size: 4,660 Bytes
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d025a55
 
 
 
 
 
 
 
 
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from transformers import ViTForImageClassification, ViTFeatureExtractor
from fake_face_detection.metrics.make_predictions import get_attention
from torchvision import transforms
import streamlit as st
from PIL import Image
import numpy as np
import pickle
import torch
import cv2

# render a page title in the app's header colour, followed by a horizontal rule
def header(text):
    """Show *text* as a centred ``<h1>`` coloured #4B4453, then a ``---`` divider.

    The text is interpolated into raw HTML (``unsafe_allow_html=True``), so it
    should only ever be a trusted, app-defined string.
    """
    title_html = (
        "<h1 style = 'color: #4B4453; text-align: center'>"
        f"{text}"
        "</h1>"
    )
    st.markdown(title_html, unsafe_allow_html=True)
    divider = "---"
    st.markdown(divider)

# image size (width, height) every upload is resized to before display and inference
size = (224, 224)

# let us add a header
header("FAKE AND REAL FACE DETECTION")

# let us add an expander to write some description of the application
expander = st.expander('Description', expanded=True)

with expander:
    st.write('''This website has for purpose to help internet users
             to know if an profil is safe by verifying if the face
             display on it is verifiable. You can download the image
             of a person in Facebook, Whatsapp or any other social media
             and add in here and click on the submit button to obtain
             the result (fake or real). You will also obtain an
             modification of the original image indicating which
             part of it is suspect or make the site identify if the
             image is real. Enjoy!''')

# three side-by-side columns: the uploaded image (left), the predicted
# label (mid) and the attention map (right)
left, mid, right = st.columns(3)

# the following function will load the model (must be in cache)
@st.cache_resource
def get_model():
    """Load the fine-tuned ViT classifier and its feature extractor.

    Decorated with ``st.cache_resource`` so the expensive load is done once and
    reused on subsequent reruns instead of being repeated on every click.

    Returns:
        tuple: ``(model, feature_extractor)``. ``model`` is the
        ``ViTForImageClassification`` restored from the local checkpoint, with
        its label mapping taken from ``fake_real_dict.txt``.
        ``feature_extractor`` is the preprocessing pipeline of the base
        ``google/vit-base-patch16-224-in21k`` model.
    """
    # let us load the image characteristics; characs['ids'] is used below as
    # the label -> id mapping.
    # NOTE: unpickling can execute arbitrary code. This is acceptable only
    # because the file is read from a fixed local path, never from user input.
    with open('data/extractions/fake_real_dict.txt', 'rb') as f:
        characs = pickle.load(f)

    # base checkpoint; supplies the preprocessing configuration
    model_name = 'google/vit-base-patch16-224-in21k'

    # Forward slashes work on every OS. The previous backslash form was not a
    # separator on POSIX, and '\c' / '\m' are invalid string escapes.
    checkpoint_dir = 'data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK'

    # recuperate the fine-tuned model with a consistent id <-> label mapping
    model = ViTForImageClassification.from_pretrained(
        checkpoint_dir,
        num_labels=len(characs['ids']),
        id2label={idx: name for name, idx in characs['ids'].items()},
        label2id=characs['ids'],
    )

    # The ViTFeatureExtractor constructor takes processing options
    # (do_resize, size, ...), not a model name; from_pretrained loads the
    # base model's saved preprocessor configuration instead.
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

    return model, feature_extractor

# let us add a file uploader
st.subheader("Choose an image to inspect")

# Streamlit warns on an empty label (accessibility), so give it a real one and
# hide it; the subheader above already acts as the visible caption.
# cv2.IMREAD_COLOR decodes JPEG and PNG alike to 3-channel BGR.
file = st.file_uploader(
    "Image to inspect",
    type=['jpg', 'jpeg', 'png'],
    label_visibility='collapsed',
)

# if the file is correctly uploaded make the next processes
if file is not None:

    # convert the uploaded bytes to an OpenCV (BGR) image
    file_bytes = np.frombuffer(file.read(), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)

    # cv2.imdecode returns None instead of raising when the bytes are not a
    # decodable image; stop this run with a readable message rather than
    # letting cv2.resize fail on None.
    if opencv_image is None:
        st.error("The uploaded file could not be decoded as an image.")
        st.stop()

    # resize the image to the size used for display and inference
    opencv_image = cv2.resize(opencv_image, size)

    # Let us display the image
    left.header("Loaded image")
    left.image(opencv_image, channels='BGR')
    left.markdown("""---""")

    if left.button("SUBMIT"):

        # OpenCV works in BGR; convert to RGB and wrap as a PIL image for the
        # feature extractor and get_attention
        image = Image.fromarray(cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB))

        with torch.no_grad():

            # Recuperate the (cached) model and the feature extractor
            model, feature_extractor = get_model()

            # Change to evaluation mode (e.g. disables dropout)
            model.eval()

            # Apply transformation on the image
            image_ = feature_extractor(image, return_tensors='pt')

            # Recuperate output from the model, including attention weights
            outputs = model(image_['pixel_values'], output_attentions=True)

            # highest-scoring class id for the first (and only) image in the batch
            prediction = outputs.logits.argmax(dim=-1)[0].item()
            label = model.config.id2label[prediction]

            # Write the prediction to the middle
            mid.markdown(f"<h2 style='text-align: center; padding: 2cm; color: black; background-color: orange; border: darkorange solid 0.3px; box-shadow: 0.2px 0.2px 0.6px 0.1px gray'>{label}</h2>", unsafe_allow_html=True)

            # attention of the last layer for the first image in the batch
            attention = outputs.attentions[-1][0]

            # Reuse the module-level `size` so the overlay matches the resized
            # input. NOTE(review): patch_size=(14, 14) presumably means the
            # 14x14 patch grid of ViT-B/16 at 224px; confirm against get_attention.
            attention_image = get_attention(image, attention, size=size, patch_size=(14, 14))

            # Let us transform the attention image to a opencv image.
            # NOTE(review): st.image expects float arrays to lie in [0.0, 1.0]
            # (it raises otherwise unless clamp=True); confirm get_attention's
            # output range.
            attention_image = cv2.cvtColor(attention_image.astype('float32'), cv2.COLOR_RGB2BGR)

            # Let us display the attention image
            right.header("Attention")
            right.image(attention_image, channels='BGR')
            right.markdown("""---""")