Spaces:
Build error
Build error
File size: 4,660 Bytes
783053f d025a55 783053f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from transformers import ViTForImageClassification, ViTFeatureExtractor
from fake_face_detection.metrics.make_predictions import get_attention
from torchvision import transforms
import streamlit as st
from PIL import Image
import numpy as np
import pickle
import torch
import cv2
# Page-title helper.
def header(text):
    """Render *text* as a centered page title (color #4B4453) followed by a
    Markdown horizontal rule."""
    title_html = (
        "<h1 style = 'color: #4B4453; text-align: center'>"
        f"{text}"
        "</h1>"
    )
    st.markdown(title_html, unsafe_allow_html=True)
    # Markdown "---" renders as a divider under the title.
    st.markdown("---")
# Target (width, height) that uploaded images are resized to before they are
# displayed and passed to the model (cv2.resize takes (width, height)).
size = (224, 224)

# Page title.
header("FAKE AND REAL FACE DETECTION")

# Collapsible description of the application, open by default.
expander = st.expander('Description', expanded=True)
with expander:
    st.write(
        "The purpose of this website is to help internet users find out "
        "whether a profile is safe by verifying whether the face displayed "
        "on it is genuine. You can download the image of a person from "
        "Facebook, WhatsApp or any other social media, upload it here and "
        "click the SUBMIT button to obtain the result (fake or real). You "
        "will also obtain a modified version of the original image "
        "indicating which parts of it look suspicious or led the site to "
        "consider the image real. Enjoy!"
    )

# Three side-by-side columns: uploaded image (left), prediction (mid) and
# attention map (right).
left, mid, right = st.columns(3)
# Cached with st.cache_resource so the model and preprocessor are loaded once
# and reused across reruns instead of being reloaded on every interaction.
@st.cache_resource
def get_model():
    """Load the fine-tuned ViT classifier and its image preprocessor.

    The label mapping is read from the pickled dictionary stored in
    ``data/extractions/fake_real_dict.txt``. Its ``'ids'`` entry maps each
    label name to its class id and is used as the model's ``label2id``.

    Returns:
        tuple: ``(model, feature_extractor)``. ``model`` is a
        ``ViTForImageClassification`` already switched to evaluation mode.
        ``feature_extractor`` is the ``ViTFeatureExtractor`` of the
        ``google/vit-base-patch16-224-in21k`` base model.
    """
    # Local file shipped with the app. Never point this at user-supplied
    # data, because unpickling can execute arbitrary code.
    with open('data/extractions/fake_real_dict.txt', 'rb') as f:
        characs = pickle.load(f)

    label2id = characs['ids']
    # Invert label -> id into the id -> label mapping the config expects.
    id2label = {class_id: label for label, class_id in label2id.items()}

    # Forward slashes keep the path portable. Backslashes only act as
    # separators on Windows, and '\c' / '\m' are invalid string escapes.
    model = ViTForImageClassification.from_pretrained(
        'data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK',
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    )
    model.eval()

    # Presumably the base model the checkpoint was fine-tuned from. Load its
    # preprocessing config (resize / rescale / normalize) from the hub.
    # Calling the constructor directly with the name would bind it to the
    # first positional parameter (do_resize) instead.
    model_name = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
    return model, feature_extractor
# File uploader. The subheader already acts as the visible label, so the
# widget's own (non-empty, for accessibility) label is collapsed.
st.subheader("Choose an image to inspect")
file = st.file_uploader(
    "Choose an image to inspect",
    type=['jpg', 'jpeg', 'png'],
    label_visibility='collapsed',
)

if file is not None:
    # getvalue() returns the whole upload regardless of the current stream
    # position, unlike read().
    data = file.getvalue()
    # IMREAD_COLOR yields a 3-channel BGR image. imdecode returns None for
    # undecodable data and asserts on an empty buffer, so guard both cases.
    opencv_image = (
        cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
        if data
        else None
    )

    if opencv_image is None:
        st.error("The uploaded file could not be decoded as an image.")
    else:
        opencv_image = cv2.resize(opencv_image, size)

        # Show the (resized) uploaded image in the left column.
        left.header("Loaded image")
        left.image(opencv_image, channels='BGR')
        left.markdown("""---""")

        if left.button("SUBMIT"):
            # The feature extractor receives a PIL image built from RGB data.
            image = Image.fromarray(cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB))

            model, feature_extractor = get_model()
            model.eval()

            # Inference only: no gradients needed.
            with torch.no_grad():
                inputs = feature_extractor(image, return_tensors='pt')
                outputs = model(inputs['pixel_values'], output_attentions=True)

            # Class with the highest logit for the first (only) image.
            predicted_id = torch.argmax(outputs.logits, dim=-1)[0].item()
            label = model.config.id2label[predicted_id]

            # Prediction badge in the middle column.
            mid.markdown(
                "<h2 style='text-align: center; padding: 2cm; color: black; "
                "background-color: orange; border: darkorange solid 0.3px; "
                "box-shadow: 0.2px 0.2px 0.6px 0.1px gray'>"
                f"{label}</h2>",
                unsafe_allow_html=True,
            )

            # Attention from the last encoder layer, first batch element.
            attention = outputs.attentions[-1][0]
            # NOTE(review): patch_size=(14, 14) presumably is the patch grid of
            # a 224x224 input split into 16x16 patches; confirm against
            # get_attention.
            attention_image = get_attention(
                image, attention, size=size, patch_size=(14, 14)
            )

            # Displayed as RGB directly. Converting RGB->BGR and then passing
            # channels='BGR' only flips the channels back. As before, the array
            # is cast to float32, so st.image presumably expects values in [0, 1].
            right.header("Attention")
            right.image(attention_image.astype('float32'), channels='RGB')
            right.markdown("""---""")
|