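"""Streamlit app for fake vs. real face detection.

The app loads a fine-tuned ViT image classifier, lets the user upload a face image,
predicts whether the face is fake or real, and displays the attention map of the
last transformer layer next to the loaded image.
"""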
from transformers import ViTForImageClassification, ViTFeatureExtractor
from fake_face_detection.metrics.make_predictions import get_attention
from torchvision import transforms
import streamlit as st
from PIL import Image
import numpy as np
import pickle
import torch
import cv2
# set the color and alignment of the header
def header(text):
    st.markdown(f"<h1 style='color: #4B4453; text-align: center'>{text}</h1>", unsafe_allow_html=True)
    st.markdown("""---""")
# initialize the size
size = (224, 224)
# let us add a header
header("FAKE AND REAL FACE DETECTION")
# let us add an expander to write some description of the application
expander = st.expander('Description', expanded=True)
with expander:
    st.write('''This is a long text lorem ipsum dolor''')
# let us initialize two columns
left, mid, right = st.columns(3)
# the following function will load the model (kept in cache between reruns)
@st.cache_resource
def get_model():
    # let us load the image characteristics
    with open('data/extractions/fake_real_dict.txt', 'rb') as f:
        depick = pickle.Unpickler(f)
        characs = depick.load()

    # define the base model name
    model_name = 'google/vit-base-patch16-224-in21k'

    # retrieve the fine-tuned model from the local checkpoint
    model = ViTForImageClassification.from_pretrained(
        'data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK',
        num_labels=len(characs['ids']),
        id2label={id_: label for label, id_ in characs['ids'].items()},
        label2id=characs['ids']
    )

    # retrieve the feature extractor associated with the base model
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

    return model, feature_extractor
# let us add a file uploader
st.subheader("Choose an image to inspect")
file = st.file_uploader("", type='jpg')
# if the file is correctly uploaded, run the next processing steps
if file is not None:
    # convert the file to an OpenCV image
    file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, 1)

    # resize the image
    opencv_image = cv2.resize(opencv_image, size)

    # let us display the image in the left column
    left.header("Loaded image")
    left.image(opencv_image, channels='BGR')
    left.markdown("""---""")
    if left.button("SUBMIT"):
        # let us convert the image format to 'RGB'
        image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)

        # let us convert the OpenCV image to a PIL image
        image = Image.fromarray(image)

        with torch.no_grad():
            # retrieve the model and the feature extractor
            model, feature_extractor = get_model()

            # switch the model to evaluation mode
            _ = model.eval()

            # apply the feature extractor's transformations to the image
            image_ = feature_extractor(image, return_tensors='pt')

            # retrieve the outputs from the model, including the attention weights
            outputs = model(image_['pixel_values'], output_attentions=True)

            # retrieve the predicted class
            predictions = torch.argmax(outputs.logits, dim=-1)

            # write the prediction in the middle column
            mid.markdown(f"<h2 style='text-align: center; padding: 2cm; color: black; background-color: orange; border: darkorange solid 0.3px; box-shadow: 0.2px 0.2px 0.6px 0.1px gray'>{model.config.id2label[predictions[0].item()]}</h2>", unsafe_allow_html=True)

            # retrieve the attention weights of the last layer (first image of the batch)
            attention = outputs.attentions[-1][0]

            # build the attention map image
            attention_image = get_attention(image, attention, size=(224, 224), patch_size=(14, 14))

            # convert the attention image to an OpenCV (BGR) image
            attention_image = cv2.cvtColor(attention_image.astype('float32'), cv2.COLOR_RGB2BGR)

            # display the attention image in the right column
            right.header("Attention")
            right.image(attention_image, channels='BGR')
            right.markdown("""---""")