# MNIST_ABC / app.py
# AI-RESEARCHER-2024's picture
# Update app.py
# 4909156 verified
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
import cv2
from PIL import Image
# (width, height) that images are resized to before being shown in the UI.
DIMS = (100, 100)

# Pre-trained Keras models, loaded from relative paths.
# `mnist_model` drives predictions and gradients; `adv_model` is the model
# the adversarial examples are crafted against.
mnist_model = load_model('mnist_model.h5')
adv_model = load_model('adv_model.h5')

# MNIST train/test splits; the code below only reads the test split.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# One clean example per digit 0-9: the first test image carrying that label.
# Each entry is a one-element list, the row format later handed to
# gr.Examples for a single image input.
mnist_examples = [
    [x_test[np.flatnonzero(y_test == digit)[0]]]
    for digit in range(10)
]
def preprocess_image(image):
    """Convert an input image into the single-image batch the models take.

    Args:
        image: A PIL image (any mode) or a NumPy array. Arrays may be 2-D
            grayscale, or 3-D with a trailing channel axis of size 1
            (squeezed) or >= 3 (reduced to luminance from the first three
            channels).

    Returns:
        A NumPy array of shape (1, 28, 28, 1) holding the resized image with
        every pixel divided by 255.

    Raises:
        ValueError: If `image` is None, e.g. the UI was run before an image
            was selected.
    """
    if image is None:
        # Fail with a readable message instead of an opaque OpenCV error.
        raise ValueError("No image provided: select or upload an image first.")

    if isinstance(image, Image.Image):
        image = np.array(image.convert('L'))  # Convert to grayscale
    else:
        image = np.asarray(image)
        if image.ndim == 3:
            if image.shape[-1] == 1:
                image = image[..., 0]
            elif image.shape[-1] >= 3:
                # Same ITU-R 601 luma weights PIL uses for mode 'L'; any
                # alpha channel is ignored. Without this the final reshape
                # fails on colour arrays.
                image = image[..., :3] @ np.array([0.299, 0.587, 0.114])

    image = cv2.resize(image, (28, 28))  # Resize to 28x28 for model input
    image = image / 255.0
    return image.reshape(1, 28, 28, 1)
def predict(image):
    """Classify one image with `mnist_model`.

    The image goes through `preprocess_image` first.

    Returns:
        A `(label, confidence)` pair: the index of the largest model output
        and the value of that output.
    """
    batch = preprocess_image(image)
    scores = mnist_model.predict(batch)
    return np.argmax(scores), np.max(scores)
def get_gradients(image, label):
    """Return the gradient of the loss with respect to every input pixel.

    Args:
        image: Array reshapeable to (1, 28, 28, 1).
        label: Target for categorical cross-entropy; it is wrapped in a list
            to form a batch of one (the caller in this file passes a one-hot
            vector).

    Returns:
        A (28, 28) NumPy array of d(loss)/d(pixel) for `mnist_model`.
    """
    batch = image.reshape((1, 28, 28, 1))
    pixels = tf.convert_to_tensor(batch, dtype=tf.float32)

    with tf.GradientTape() as tape:
        # `pixels` is a plain tensor, so the tape must be told to track it.
        tape.watch(pixels)
        scores = mnist_model(pixels)
        loss = tf.keras.losses.categorical_crossentropy([label], scores)

    pixel_grads = tape.gradient(loss, pixels)
    return pixel_grads.numpy().reshape(28, 28)
def progressively_mask_image(image, steps=100, increment=5):
    """Zero out the highest-gradient pixels until the prediction changes.

    At step `i` every pixel whose absolute gradient lies above the
    `(100 - i * increment)`-th percentile is set to 0, and `mnist_model` is
    queried again. Masking stops as soon as the predicted label differs from
    the original one, after `steps` steps, or once the percentile has reached
    0 (further steps could not mask anything new).

    Args:
        image: Anything `preprocess_image` accepts.
        steps: Maximum number of masking steps; coerced to int so a float
            coming from a UI slider is accepted.
        increment: Percentile points dropped per step.

    Returns:
        A tuple `(masked_image, original_label, predicted_label)` where
        `masked_image` is resized to `DIMS`. When no step runs
        (`steps <= 0`), `predicted_label` equals `original_label`.
    """
    image = preprocess_image(image).reshape(28, 28)

    # One forward pass supplies both the attribution target and the baseline
    # label that the masked predictions are compared against.
    original_prediction = mnist_model.predict(image.reshape(1, 28, 28, 1))
    original_label = np.argmax(original_prediction)

    gradients = get_gradients(image, to_categorical(original_label, 10))
    saliency = np.abs(gradients)  # loop-invariant, so computed once

    modified_image = np.copy(image)
    predicted_label = original_label  # defined even if the loop never runs

    for i in range(1, int(steps) + 1):
        # np.percentile rejects q outside [0, 100]; the defaults
        # (steps=100, increment=5) would otherwise reach q = -5 at step 21.
        percentile = min(100.0, max(0.0, 100 - i * increment))
        threshold = np.percentile(saliency, percentile)
        modified_image[saliency > threshold] = 0

        modified_prediction = mnist_model.predict(
            modified_image.reshape(1, 28, 28, 1)
        )
        predicted_label = np.argmax(modified_prediction)
        if predicted_label != original_label:
            break
        if percentile == 0.0:
            # The mask cannot grow any further; later steps would repeat
            # this exact prediction.
            break

    return cv2.resize(modified_image, DIMS), original_label, predicted_label
def gradio_predict(image):
    """Gradio-facing wrapper: run `predict` and format the result as text."""
    label, score = predict(image)
    return f"Predicted Label: {label}, Confidence: {score:.4f}"
def gradio_mask(image, steps, increment=1):
    """Gradio-facing wrapper around `progressively_mask_image`.

    Returns:
        The masked image, followed by a one-line summary of the label before
        and after masking.
    """
    outcome = progressively_mask_image(image, steps, increment)
    masked, before, after = outcome
    summary = f"Original Label: {before}, New Label: {after}"
    return masked, summary
def fgsm_attack(image, epsilon, data_grad):
    """Apply one Fast Gradient Sign Method step to `image`.

    Each pixel is moved by `epsilon` in the direction of the sign of
    `data_grad`, and the result is clipped to the range [0, 1].
    """
    step = epsilon * tf.sign(data_grad)
    return tf.clip_by_value(image + step, 0, 1)
def create_adversarial_pattern(input_image, input_label):
    """Return the gradient of `adv_model`'s loss with respect to the input.

    Args:
        input_image: Tensor fed to `adv_model`; it is watched explicitly on
            the gradient tape.
        input_label: Integer class target for sparse categorical
            cross-entropy.

    Returns:
        The gradient tensor d(loss)/d(input_image).
    """
    # NOTE(review): from_logits=True treats adv_model's output as raw logits;
    # confirm the saved model has no final softmax.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    with tf.GradientTape() as tape:
        tape.watch(input_image)
        logits = adv_model(input_image)
        loss = loss_fn(input_label, logits)

    return tape.gradient(loss, input_image)
# Build one FGSM adversarial example per digit from the clean examples.
# This has to run before the clean examples are resized for display below.
epsilon = 0.20  # Tweak epsilon to change the intensity of perturbation
adversarial_examples = []
for i, (clean_digit,) in enumerate(mnist_examples):
    img = tf.convert_to_tensor(preprocess_image(clean_digit), dtype=tf.float32)
    # The example's position in the list is its digit class, so it doubles
    # as the label, shaped (1, 1) for the loss.
    label = tf.reshape(tf.convert_to_tensor([i]), [1, 1])
    adv_x = fgsm_attack(img, epsilon, create_adversarial_pattern(img, label))
    adversarial_examples.append([cv2.resize(adv_x.numpy().squeeze(), DIMS)])

# Resize the clean examples to the display size as well.
mnist_examples = [[cv2.resize(digit_img, DIMS)] for (digit_img,) in mnist_examples]
class GradioInterface:
    """Builds the Gradio UI for the attribution-masking demo."""

    def __init__(self):
        """Capture the module-level example lists shown in the UI."""
        self.preloaded_examples = mnist_examples
        self.adversarial_examples = adversarial_examples

    def create_interface(self):
        """Assemble the Blocks layout and return it without launching.

        The page has a styled HTML header, an input column (image, slider
        and two example galleries), an output column (result image and
        text), and a Run button wired to `gradio_mask`.

        Returns:
            The constructed `gr.Blocks` object.
        """
        app_styles = """
<style>
/* Global Styles */
body, #root {
font-family: Helvetica, Arial, sans-serif;
background-color: #1a1a1a;
color: #fafafa;
}
/* Header Styles */
.app-header {
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
padding: 24px;
border-radius: 8px;
margin-bottom: 24px;
text-align: center;
}
.app-title {
font-size: 48px;
margin: 0;
color: #fafafa;
}
.app-subtitle {
font-size: 24px;
margin: 8px 0 16px;
color: #fafafa;
}
.app-description {
font-size: 16px;
line-height: 1.6;
opacity: 0.8;
margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.publication-link {
display: inline-flex;
align-items: center;
padding: 8px 16px;
background-color: #333;
color: #fff !important;
text-decoration: none !important;
border-radius: 20px;
font-size: 14px;
transition: background-color 0.3s;
}
.publication-link:hover {
background-color: #555;
}
.publication-link i {
margin-right: 8px;
}
/* Content Styles */
.content-container {
background-color: #2a2a2a;
border-radius: 8px;
padding: 24px;
margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
max-width: 512px;
max-height: 512px;
margin: 0 auto;
border-radius: 4px;
display: block;
object-fit: contain;
}
/* Control Styles */
.control-panel {
background-color: #333;
padding: 16px;
border-radius: 8px;
margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
background-color: #4a4a4a;
color: #fff;
border: none;
border-radius: 4px;
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
background-color: #3a3a3a;
color: #fff;
border: 1px solid #4a4a4a;
border-radius: 4px;
padding: 8px;
}
.gr-form {
background-color: transparent;
}
.gr-panel {
border: none;
background-color: transparent;
}
/* Override any conflicting styles from Bulma */
.button.is-normal.is-rounded.is-dark {
color: #fff !important;
text-decoration: none !important;
}
</style>
"""
        # NOTE(review): the Bulma link previously read "[email protected]" where
        # the package@version segment belongs, which is not a valid CDN path.
        # It is pinned to bulma@0.9.4 here; confirm the intended version.
        header_html = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.4/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{app_styles}
<div class="app-header">
<h1 class="app-title">Attribution Based Confidence Metric for Neural Networks</h1>
<h2 class="app-subtitle">Steven Fernandes, Ph.D.</h2>
</div>
"""
        # Passed to gr.Blocks(js=...): reloads the page with ?__theme=dark
        # unless that query parameter is already set.
        js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    steps_input = gr.Slider(minimum=1, maximum=100, label="Attributions Drop Percentage", step=1, value=5)
                    # The Examples objects register themselves with the
                    # enclosing Blocks context; no reference is needed.
                    gr.Examples(
                        examples=self.preloaded_examples,
                        inputs=input_image,
                        label="MNIST Examples",
                    )
                    gr.Examples(
                        examples=self.adversarial_examples,
                        inputs=input_image,
                        label="Adversarial Examples",
                    )
                with gr.Column():
                    result = gr.Image(label="Result", elem_classes="image-preview")
                    prediction_details = gr.Textbox(label="Prediction Details")
            run_button = gr.Button("Run", elem_classes="gr-button")
            # The slider value is passed as `steps`; `increment` keeps
            # gradio_mask's default.
            run_button.click(
                fn=gradio_mask,
                inputs=[input_image, steps_input],
                outputs=[result, prediction_details],
            )
        return demo
def main():
    """Build the Gradio interface and launch it with debug enabled."""
    GradioInterface().create_interface().launch(debug=True)


if __name__ == "__main__":
    main()