from transformers import ViTForImageClassification, ViTFeatureExtractor
from fake_face_detection.metrics.make_predictions import get_attention
import streamlit as st
from PIL import Image
import numpy as np
import pickle
import torch
import cv2
# helper to render a styled, centered header
def header(text):
st.markdown(f"<h1 style = 'color: #4B4453; text-align: center'>{text}</h1>", unsafe_allow_html=True)
st.markdown("""---""")
# initialize the size
size = (224, 224)
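# (224x224 is the input resolution expected by the ViT checkpoint used below)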
# let us add a header
header("FAKE AND REAL FACE DETECTION")
# let us add an expander to write some description of the application
expander = st.expander('Description', expanded=True)
with expander:
    st.write('''This website helps internet users check whether a profile
    is trustworthy by verifying that the face displayed on it is genuine.
    Download the picture of a person from Facebook, WhatsApp or any other
    social network, upload it here and click on the submit button to obtain
    the result (fake or real). You will also obtain a modified version of
    the original image highlighting which parts of it look suspicious or
    led the model to identify the image as real. Enjoy!''')
# let us initialize three columns
left, mid, right = st.columns(3)
# the following function loads the model (cached so it runs only once)
@st.cache_resource
def get_model():
    # load the label dictionary extracted from the dataset
    with open('data/extractions/fake_real_dict.txt', 'rb') as f:
        characs = pickle.load(f)
    # define the name of the pretrained backbone
    model_name = 'google/vit-base-patch16-224-in21k'
    # load the fine-tuned model from the local checkpoint
    model = ViTForImageClassification.from_pretrained(
        'data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK',
        num_labels=len(characs['ids']),
        id2label={idx: name for name, idx in characs['ids'].items()},
        label2id=characs['ids']
    )
    # load the feature extractor matching the backbone
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
    return model, feature_extractor
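# Note: st.cache_resource keeps the returned model and feature extractor in
# memory across Streamlit reruns, so the checkpoint is only loaded once.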
# let us add a file uploader
st.subheader("Choose an image to inspect")
file = st.file_uploader("", type='jpg')
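# Note: only .jpg files are accepted here; passing e.g. type=['jpg', 'jpeg', 'png']
# to st.file_uploader would support more formats.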
# if the file is correctly uploaded make the next processes
if file is not None:
    # decode the uploaded bytes into an OpenCV image (flag 1 = cv2.IMREAD_COLOR, a 3-channel BGR array)
    file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, 1)
    # resize the image to the model's input resolution
    opencv_image = cv2.resize(opencv_image, size)
    # display the image in the left column
    left.header("Loaded image")
    left.image(opencv_image, channels='BGR')
    left.markdown("""---""")
    if left.button("SUBMIT"):
        # convert the image from BGR to RGB
        image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
        # convert the OpenCV array to a PIL image
        image = Image.fromarray(image)
        with torch.no_grad():
            # load the cached model and feature extractor
            model, feature_extractor = get_model()
            # switch to evaluation mode
            _ = model.eval()
            # preprocess the image
            image_ = feature_extractor(image, return_tensors='pt')
            # run the model, keeping the attention weights
            outputs = model(image_['pixel_values'], output_attentions=True)
            # take the most probable class
            predictions = torch.argmax(outputs.logits, dim=-1)
            # write the predicted label in the middle column
            mid.markdown(f"<h2 style='text-align: center; padding: 2cm; color: black; background-color: orange; border: darkorange solid 0.3px; box-shadow: 0.2px 0.2px 0.6px 0.1px gray'>{model.config.id2label[predictions[0].item()]}</h2>", unsafe_allow_html=True)
            # take the attention weights of the last layer for this image
            attention = outputs.attentions[-1][0]
            # build the attention overlay image
            attention_image = get_attention(image, attention, size=(224, 224), patch_size=(14, 14))
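            # Note: with 16x16 pixel patches on a 224x224 input, the ViT grid is
            # 14x14 patches, which is presumably why patch_size=(14, 14) is passed
            # to the project-specific get_attention helper.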
            # convert the attention image to an OpenCV (BGR) array
            attention_image = cv2.cvtColor(attention_image.astype('float32'), cv2.COLOR_RGB2BGR)
            # display the attention image in the right column
            right.header("Attention")
            right.image(attention_image, channels='BGR')
            right.markdown("""---""")
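# To try the app locally (assuming this script is saved as app.py and the
# checkpoint and pickle files referenced above are present), run:
#   streamlit run app.py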