# imports
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
# load css
def load_local_css(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
load_local_css("./styles/style.css")
# bootstrap
st.markdown(
    """<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-QWTKZyjpPEjISv5WaRU9OFeRpok6YctnYmDr5pNlyT2bRjXh0JMhjY6hW+ALEwIH" crossorigin="anonymous">""",
    unsafe_allow_html=True
)
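# Note: st.cache_resource caches the returned objects across Streamlit reruns, so the
# models below are loaded once and reused instead of being reloaded on every interaction.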
@st.cache_resource
def load_models():
    # Load both EfficientNet models; the CNN is rebuilt in run_cnn, so only its weights path is kept here
    eff_net_model = tf.keras.models.load_model('EfficientNet_Models/efficientnetb3_binary_classifier_8.h5')
    eff_net_art_model = tf.keras.models.load_model('EfficientNet_Models/EfficientNet_fine_tune_art_model.h5')
    cnn_model = 'CNN_model_weight/model_weights.weights.h5'  # path to weights, loaded later in run_cnn
    return eff_net_model, eff_net_art_model, cnn_model
# Access cached models
eff_net_model, eff_net_art_model, cnn_model = load_models()
# CNN model
def run_cnn(img_arr):
    # Rebuild the CNN architecture and load the pre-trained weights on each call
    my_model = Sequential()
    my_model.add(Conv2D(
        filters=16,
        kernel_size=(3, 3),
        strides=(1, 1),
        activation='relu',
        input_shape=(256, 256, 3)
    ))
    my_model.add(BatchNormalization())
    my_model.add(MaxPooling2D())
    my_model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu'))
    my_model.add(BatchNormalization())
    my_model.add(MaxPooling2D())
    my_model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))
    my_model.add(BatchNormalization())
    my_model.add(MaxPooling2D())
    my_model.add(Flatten())
    my_model.add(Dense(512, activation='relu'))
    my_model.add(Dropout(0.09))
    my_model.add(Dense(1, activation='sigmoid'))
    my_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    # Load the pre-trained weights
    my_model.load_weights(cnn_model)
    prediction = my_model.predict(img_arr)
    return prediction
def run_effNet(img_arr):
    # Use a TPU strategy if one is available; otherwise fall back to the default strategy
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except ValueError:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        prediction = eff_net_model.predict(img_arr)
    return prediction
def run_effNet_Art(img_arr):
    # Same TPU/default-strategy fallback as run_effNet, but with the art model
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except ValueError:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        prediction = eff_net_art_model.predict(img_arr)
    return prediction
def pre_process_img_effNet(image):
    img = load_img(image, target_size=(300, 300))  # Resize image to model input size
    img_arr = img_to_array(img)  # Convert to array
    img_arr = np.expand_dims(img_arr, axis=0)  # Add batch dimension
    result = run_effNet(img_arr)
    return result
def pre_process_img_effNetArt(image):
    img = load_img(image, target_size=(224, 224))  # Resize image to model input size
    img_arr = img_to_array(img)  # Convert to array
    img_arr = np.expand_dims(img_arr, axis=0)  # Add batch dimension
    result = run_effNet_Art(img_arr)
    return result
# preprocess image for cnn
def pre_process_img(image):
    # Load and preprocess the image
    input_picture = load_img(image, target_size=(256, 256))
    img_arr = img_to_array(input_picture) / 255.0  # Normalize the image
    img_arr = img_arr.reshape((1, 256, 256, 3))  # Add batch dimension
    result = run_cnn(img_arr)
    return result
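# Note: only the CNN pipeline rescales pixels to [0, 1]; the EfficientNet pipelines pass
# raw pixel values, presumably because those saved models handle rescaling internally
# (as the Keras EfficientNet applications do).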
# title
st.markdown(
    """<p class="title"> AI vs REAL Image Detection </p>""",
    unsafe_allow_html=True
)
# upload image
st.markdown(
    """<p class="upload_line"> Please upload the image </p>""",
    unsafe_allow_html=True
)
# introduce states
if "prev_image" not in st.session_state:
    st.session_state.prev_image = None
if "reset_model" not in st.session_state:
    st.session_state.reset_model = False
if "model_key" not in st.session_state:
    st.session_state.model_key = "default_model_key"
user_image = st.file_uploader("png, jpg, or jpeg image", ['png', 'jpg', 'jpeg'], label_visibility='hidden')
if user_image != st.session_state.prev_image:
    if st.session_state.prev_image is not None:
        st.session_state.model_key = "reset_model_key" if st.session_state.model_key == "default_model_key" else "default_model_key"
        st.session_state.reset_model = True
    st.session_state.prev_image = user_image  # set prev image to current image
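# Note: toggling the selectbox key between two values forces Streamlit to treat it as a
# new widget, which clears the previous model selection whenever a new image is uploaded.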
model_name = st.selectbox(
    'Choose a model',
    ['CNN', 'Efficientnet', 'Efficientnet Art'],
    index=None,
    placeholder='choose an option',
    key=st.session_state.model_key
)
result_placeholder = st.empty()
# design animation elements
with open("styles/detectiveMag.svg", "r") as file:
    svg_content_detective_Mag = file.read()
# First magnifying glass starts at bottom-right
st.markdown(
    f"<div class='detectiveMag1' style='bottom: 0%; right: 0%;'>{svg_content_detective_Mag}</div>",
    unsafe_allow_html=True
)
# Second magnifying glass starts slightly higher up the diagonal
st.markdown(
    f"<div class='detectiveMag2' style='bottom: 10%; right: 10%;'>{svg_content_detective_Mag}</div>",
    unsafe_allow_html=True
)
# Third magnifying glass starts further up the diagonal
st.markdown(
    f"<div class='detectiveMag3' style='bottom: 20%; right: 20%;'>{svg_content_detective_Mag}</div>",
    unsafe_allow_html=True
)
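# Run the selected model once both an image and a model choice are available; each model
# returns a single sigmoid score, and scores below 0.5 are reported as FAKE.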
if user_image is not None and model_name is not None:
    predictions = []
    if model_name == 'CNN':
        print('CNN is running')
        predictions = pre_process_img(user_image)
    elif model_name == 'Efficientnet':
        print('Effnet is running')
        predictions = pre_process_img_effNet(user_image)
    elif model_name == 'Efficientnet Art':
        print('Effnet Art is running')
        predictions = pre_process_img_effNetArt(user_image)
    if predictions[0] < 0.5:
        result_word = "FAKE"
    else:
        result_word = "REAL"
    if len(predictions) > 0:
        result_placeholder.markdown(
            f"<div class='result'> <span class='prediction'>Prediction: {predictions[0][0]}</span> <br> It is a <span class='resultword'> {result_word} </span> image. </div>",
            unsafe_allow_html=True
        )
    print(model_name)
    print(predictions[0])