File size: 4,135 Bytes
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from transformers import ViTForImageClassification, ViTFeatureExtractor
from fake_face_detection.metrics.make_predictions import get_attention
from torchvision import transforms
import streamlit as st
from PIL import Image
import numpy as np
import pickle
import torch
import cv2

# render the page title: colored, centered, with a separator underneath
def header(text):
    """Write *text* as a centered, colored ``<h1>`` title followed by a rule.

    Args:
        text: Title to display. It is interpolated into raw HTML as-is
            (no escaping), so it must not contain untrusted content.
    """
    title_html = f"<h1 style = 'color: #4B4453; text-align: center'>{text}</h1>"
    st.markdown(title_html, unsafe_allow_html=True)

    # a markdown horizontal rule separates the title from the page body
    st.markdown("---")

# size every uploaded image is resized to (passed to cv2.resize further down)
size = (224, 224)

# let us add a header
header("FAKE AND REAL FACE DETECTION")

# let us add an expander to write some description of the application
# (opened by default because of expanded=True)
expander = st.expander('Description', expanded=True)

with expander:
    st.write('''This is a long text lorem ipsum dolor''')

# let us initialize three columns: the loaded image goes to the left one,
# the prediction to the middle one and the attention image to the right one
left, mid, right = st.columns(3)

# the following function will load the model (must be in cache)
@st.cache_resource
def get_model():
    """Load the fine-tuned ViT classifier and its image pre-processor.

    The function is wrapped in ``st.cache_resource`` so the checkpoint is not
    reloaded every time Streamlit reruns the script.

    Returns:
        tuple: ``(model, feature_extractor)`` where ``model`` is a
        ``ViTForImageClassification`` restored from the local checkpoint and
        ``feature_extractor`` is a ``ViTFeatureExtractor`` built with the
        library defaults.
    """

    # let us load the image characteristics
    # NOTE: unpickling can execute arbitrary code. This is acceptable only
    # because the path is a fixed local file; never point it at user input.
    with open('data/extractions/fake_real_dict.txt', 'rb') as f:
        characs = pickle.load(f)

    # `characs['ids']` is handed to the model as label2id, so the reverse
    # mapping (id -> label) is obtained by inverting it
    label2id = characs['ids']
    id2label = {idx: label for label, idx in label2id.items()}

    # recuperate the model
    # Forward slashes: the previous 'data\checkpoints\...' literal relied on
    # invalid escape sequences ('\c', '\m') and only resolved on Windows.
    model = ViTForImageClassification.from_pretrained(
        'data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK',
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    )

    # recuperate the feature_extractor
    # The constructor's first positional parameter is `do_resize`, not a
    # checkpoint name, so passing a model name there only acted as a truthy
    # flag. Building the extractor with its defaults states that behavior
    # explicitly and keeps the pre-processing unchanged.
    feature_extractor = ViTFeatureExtractor()

    return model, feature_extractor

# let us add a file uploader
st.subheader("Choose an image to inspect")

# Streamlit discourages empty widget labels (accessibility) and may reject
# them in the future: give the widget a real label and hide it, since the
# subheader above already acts as the visible caption.
file = st.file_uploader(
    "Choose an image to inspect",
    type='jpg',
    label_visibility='collapsed',
)

# if the file is correctly uploaded make the next processes
if file is not None:

    # convert the file to an opencv image
    file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)

    # cv2.imdecode raises on an empty buffer and returns None when the bytes
    # are not a decodable image, so both cases are mapped to None here
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None

    if opencv_image is None:

        # report the problem instead of crashing in cv2.resize
        left.error("The uploaded file could not be read as an image.")

    else:

        # resize the image
        opencv_image = cv2.resize(opencv_image, size)

        # Let us display the image
        left.header("Loaded image")

        left.image(opencv_image, channels='BGR')

        left.markdown("""---""")

        if left.button("SUBMIT"):

            # Let us convert the image format to 'RGB'
            image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)

            # Let us convert from opencv image to pil image
            image = Image.fromarray(image)

            # inference only: no gradient tracking is needed
            with torch.no_grad():

                # Recuperate the model and the feature extractor
                model, feature_extractor = get_model()

                # Change to evaluation mode
                model.eval()

                # Apply transformation on the image
                image_ = feature_extractor(image, return_tensors='pt')

                # Recuperate output from the model (attentions are requested
                # because they are displayed in the right column)
                outputs = model(image_['pixel_values'], output_attentions=True)

                # Recuperate the predictions
                predictions = torch.argmax(outputs.logits, dim=-1)

                # Write the prediction to the middle
                mid.markdown(f"<h2 style='text-align: center; padding: 2cm; color: black; background-color: orange; border: darkorange solid 0.3px; box-shadow: 0.2px 0.2px 0.6px 0.1px gray'>{model.config.id2label[predictions[0].item()]}</h2>", unsafe_allow_html=True)

                # Let us recuperate the attention: last entry of the returned
                # attentions, first (and only) image of the batch
                attention = outputs.attentions[-1][0]

                # Let us recuperate the attention image (same size as the
                # displayed image, hence the shared `size` constant)
                attention_image = get_attention(image, attention, size=size, patch_size=(14, 14))

                # Let us transform the attention image to a opencv image
                attention_image = cv2.cvtColor(attention_image.astype('float32'), cv2.COLOR_RGB2BGR)

                # Let us display the attention image
                right.header("Attention")

                right.image(attention_image, channels='BGR')

                right.markdown("""---""")