# MNIST_ABC/app.py
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
import cv2
# I/O image dimensions for display
DIMS = (100, 100)
# Load the trained model
model = load_model('mnist_model.h5')
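# Assumes mnist_model.h5 is a Keras classifier trained on 28x28 grayscale MNIST
# digits with a 10-way softmax output, matching the (1, 28, 28, 1) inputs and
# 10-class one-hot labels used below.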
# Load MNIST examples
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Select the first 10 test digits and resize them to the display size
mnist_examples = [[cv2.resize(x_test[i], DIMS)] for i in range(10)]
# Function to preprocess the image
def preprocess_image(image):
image = cv2.resize(image, (28, 28)) # Resize to 28x28 for model input
image = np.array(image) / 255.0
image = image.reshape(1, 28, 28, 1)
return image
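# Usage sketch (not executed): preprocess_image takes a grayscale numpy array of
# any size and returns a (1, 28, 28, 1) float array scaled to [0, 1], e.g.
#   batch = preprocess_image(x_test[0])   # x_test[0] is a (28, 28) uint8 digit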
# Function to make predictions
def predict(image):
image = preprocess_image(image)
prediction = model.predict(image)
predicted_label = np.argmax(prediction)
return predicted_label, np.max(prediction)
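# Usage sketch (not executed): predict returns the argmax class and its softmax
# score, e.g.
#   label, confidence = predict(x_test[0])   # label in 0..9, confidence in [0, 1]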
# Function to compute gradients
def get_gradients(image, label):
image = tf.convert_to_tensor(image.reshape((1, 28, 28, 1)), dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(image)
prediction = model(image)
loss = tf.keras.losses.categorical_crossentropy([label], prediction)
gradients = tape.gradient(loss, image)
return gradients.numpy().reshape(28, 28)
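# The returned array is a per-pixel saliency map: |d loss / d pixel| measures how
# strongly each input pixel influences the categorical cross-entropy for the
# given one-hot label. Usage sketch (not executed):
#   saliency = get_gradients(preprocess_image(x_test[0]),
#                            to_categorical(y_test[0], 10))   # shape (28, 28)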
# Function to progressively mask image and observe changes
def progressively_mask_image(image, steps=100, increment=5):
    image = preprocess_image(image).reshape(28, 28)
    original_prediction = model.predict(image.reshape(1, 28, 28, 1))
    original_label = np.argmax(original_prediction)
    # Saliency of the original prediction guides which pixels to mask first
    gradients = get_gradients(image, to_categorical(original_label, 10))
    modified_image = np.copy(image)
for i in range(1, steps + 1):
        threshold = np.percentile(np.abs(gradients), max(100 - i * increment, 0))  # keep the percentile within [0, 100]
mask = np.abs(gradients) > threshold
modified_image[mask] = 0
modified_prediction = model.predict(modified_image.reshape(1, 28, 28, 1))
predicted_label = np.argmax(modified_prediction)
if predicted_label != original_label:
break
return cv2.resize(modified_image, DIMS), original_label, predicted_label
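# progressively_mask_image zeroes out the highest-|gradient| pixels in growing
# percentile bands (increment percent per step) until the model's prediction
# flips, then returns the masked image (resized for display) together with the
# original and new labels.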
# Gradio interface functions
def gradio_predict(image):
predicted_label, confidence = predict(image)
return f"Predicted Label: {predicted_label}, Confidence: {confidence:.4f}"
def gradio_mask(image, steps, increment):
modified_image, original_label, predicted_label = progressively_mask_image(image, steps, increment)
return modified_image, f"Original Label: {original_label}, New Label: {predicted_label}"
class GradioInterface:
    def __init__(self):
        self.preloaded_examples = self.preload_examples()

    def preload_examples(self):
        # Use the resized MNIST test digits loaded above as example inputs
        return mnist_examples
def create_interface(self):
app_styles = """
<style>
/* Global Styles */
body, #root {
font-family: Helvetica, Arial, sans-serif;
background-color: #1a1a1a;
color: #fafafa;
}
/* Header Styles */
.app-header {
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
padding: 24px;
border-radius: 8px;
margin-bottom: 24px;
text-align: center;
}
.app-title {
font-size: 48px;
margin: 0;
color: #fafafa;
}
.app-subtitle {
font-size: 24px;
margin: 8px 0 16px;
color: #fafafa;
}
.app-description {
font-size: 16px;
line-height: 1.6;
opacity: 0.8;
margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.publication-link {
display: inline-flex;
align-items: center;
padding: 8px 16px;
background-color: #333;
color: #fff !important;
text-decoration: none !important;
border-radius: 20px;
font-size: 14px;
transition: background-color 0.3s;
}
.publication-link:hover {
background-color: #555;
}
.publication-link i {
margin-right: 8px;
}
/* Content Styles */
.content-container {
background-color: #2a2a2a;
border-radius: 8px;
padding: 24px;
margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
max-width: 512px;
max-height: 512px;
margin: 0 auto;
border-radius: 4px;
display: block;
object-fit: contain;
}
/* Control Styles */
.control-panel {
background-color: #333;
padding: 16px;
border-radius: 8px;
margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
background-color: #4a4a4a;
color: #fff;
border: none;
border-radius: 4px;
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
background-color: #3a3a3a;
color: #fff;
border: 1px solid #4a4a4a;
border-radius: 4px;
padding: 8px;
}
.gr-form {
background-color: transparent;
}
.gr-panel {
border: none;
background-color: transparent;
}
/* Override any conflicting styles from Bulma */
.button.is-normal.is-rounded.is-dark {
color: #fff !important;
text-decoration: none !important;
}
</style>
"""
header_html = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{app_styles}
<div class="app-header">
<h1 class="app-title">Attribution Based Confidence Metric for Neural Networks</h1>
<h2 class="app-subtitle">Steven Fernandes, Ph.D.</h2>
</div>
"""
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
gr.HTML(header_html)
with gr.Row(elem_classes="content-container"):
with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy", image_mode="L", elem_classes="image-preview")  # grayscale numpy array so preprocess_image can resize it with cv2
steps_input = gr.Slider(minimum=1, maximum=100, label="Steps", step=1, value=100)
increment_input = gr.Slider(minimum=1, maximum=20, label="Increment", step=1, value=5)
                with gr.Column():
                    result = gr.Image(label="Result", elem_classes="image-preview")
                    prediction_details = gr.Textbox(label="Prediction Details")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            run_button.click(
                fn=gradio_mask,
                inputs=[input_image, steps_input, increment_input],
                outputs=[result, prediction_details],
            )
return demo
def main():
interface = GradioInterface()
demo = interface.create_interface()
demo.launch(debug=True)
if __name__ == "__main__":
main()