Spaces:
Build error
Build error
from transformers import ViTForImageClassification, ViTFeatureExtractor | |
from fake_face_detection.metrics.make_predictions import get_attention | |
from torchvision import transforms | |
import streamlit as st | |
from PIL import Image | |
import numpy as np | |
import pickle | |
import torch | |
import cv2 | |
# set the color of the header | |
def header(text): | |
st.markdown(f"<h1 style = 'color: #4B4453; text-align: center'>{text}</h1>", unsafe_allow_html=True) | |
st.markdown("""---""") | |
# initialize the size | |
size = (224, 224) | |
# let us add a header | |
header("FAKE AND REAL FACE DETECTION") | |
# let us add an expander to write some description of the application | |
expander = st.expander('Description', expanded=True) | |
with expander: | |
st.write('''This website aims to help internet users | |
know if a profile is safe by verifying if its displayed face is verifiable. You can download the image | |
of a person on Facebook, Whatsapp, or any other social media | |
and add it here and click on the submit button to obtain | |
the result (fake or actual). You will also receive a | |
modification of the original image indicating which | |
part of it is suspect or make the site identify if the | |
picture is accurate. Enjoy!''') | |
# let us initialize two columns | |
left, mid, right = st.columns(3) | |
# the following function will load the model (must be in cache) | |
def get_model(): | |
# let us load the image characteristics | |
with open('data/extractions/fake_real_dict.txt', 'rb') as f: | |
depick = pickle.Unpickler(f) | |
characs = depick.load() | |
# define the model name | |
model_name = 'google/vit-base-patch16-224-in21k' | |
# recuperate the model | |
model = ViTForImageClassification.from_pretrained( | |
'data/checkpoints/model_lhGqMDq/checkpoint-440', | |
num_labels = len(characs['ids']), | |
id2label = {name: key for key, name in characs['ids'].items()}, | |
label2id = characs['ids'] | |
) | |
# recuperate the feature_extractor | |
feature_extractor = ViTFeatureExtractor(model_name) | |
return model, feature_extractor | |
# let us add a file uploader | |
st.subheader("Choose an image to inspect") | |
file = st.file_uploader("", type='jpg') | |
# if the file is correctly uploaded make the next processes | |
if file is not None: | |
# convert the file to an opencv image | |
file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8) | |
opencv_image = cv2.imdecode(file_bytes, 1) | |
# resize the image | |
opencv_image = cv2.resize(opencv_image, size) | |
# Let us display the image | |
left.header("Loaded image") | |
left.image(opencv_image, channels='BGR') | |
left.markdown("""---""") | |
# initiliaze the smoothing parameters | |
smooth_scale = st.sidebar.slider("Smooth scale", min_value=0.1, max_value =1.0, step = 0.1) | |
smooth_thres = st.sidebar.slider("Smooth thres", min_value=0.01, max_value =1.0, step = 0.01) | |
smooth_size = st.sidebar.slider("Smooth size", min_value=1, max_value =10) | |
smooth_iter = st.sidebar.slider("Smooth iter", min_value=1, max_value =10) | |
# add a side for the scaler and the head number | |
scale = st.sidebar.slider("Attention scale", min_value=30, max_value =200) | |
head = int(st.sidebar.selectbox("Attention head", options=list(range(1, 13)))) | |
if left.button("SUBMIT"): | |
# Let us convert the image format to 'RGB' | |
image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB) | |
# Let us convert from opencv image to pil image | |
image = Image.fromarray(image) | |
with torch.no_grad(): | |
# Recuperate the model and the feature extractor | |
model, feature_extractor = get_model() | |
# Change to evaluation mode | |
_ = model.eval() | |
# Apply transformation on the image | |
image_ = feature_extractor(image, return_tensors = 'pt') | |
# # Recuperate output from the model | |
outputs = model(image_['pixel_values'], output_attentions = True) | |
# Recuperate the predictions | |
predictions = torch.argmax(outputs.logits, axis = -1) | |
# Write the prediction to the middle | |
mid.markdown(f"<h2 style='text-align: center; padding: 2cm; color: black; background-color: orange; border: darkorange solid 0.3px; box-shadow: 0.2px 0.2px 0.6px 0.1px gray'>{model.config.id2label[predictions[0].item()]}</h2>", unsafe_allow_html=True) | |
# Let us recuperate the attention | |
attention = outputs.attentions[-1][0] | |
# Let us recuperate the attention image | |
attention_image = get_attention(image, attention, size = (224, 224), patch_size = (14, 14), scale = scale, head = head, smooth_scale = smooth_scale, smooth_thres=smooth_thres, smooth_size = smooth_size, smooth_iter = smooth_iter) | |
# Let us transform the attention image to a opencv image | |
attention_image = cv2.cvtColor(attention_image.astype('float32'), cv2.COLOR_RGB2BGR) | |
# Let us display the attention image | |
right.header("Attention") | |
right.image(attention_image, channels='BGR') | |
right.markdown("""---""") | |