import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
import cv2
from PIL import Image

# I/O image dimensions for display
DIMS = (100, 100)

# Load the trained models
mnist_model = load_model('mnist_model.h5')
adv_model = load_model('adv_model.h5')

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Select one example for each digit 0-9
mnist_examples = []
for digit in range(10):
    idx = np.where(y_test == digit)[0][0]
    mnist_examples.append([x_test[idx]])

# Preprocess an input image into the (1, 28, 28, 1) float tensor the models expect
def preprocess_image(image):
    if isinstance(image, Image.Image):
        image = np.array(image.convert('L'))  # Convert to grayscale
    image = cv2.resize(image, (28, 28))  # Resize to 28x28 for model input
    image = image / 255.0  # Scale pixel values to [0, 1]
    image = image.reshape(1, 28, 28, 1)
    return image
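
# A minimal sanity-check sketch (illustrative only, not wired into the app):
# it confirms that preprocess_image yields the shape and [0, 1] value range
# the models expect. The random array stands in for an arbitrary upload.
def _check_preprocess():
    dummy = np.random.randint(0, 256, size=(100, 100), dtype=np.uint8)
    out = preprocess_image(dummy)
    assert out.shape == (1, 28, 28, 1)
    assert out.min() >= 0.0 and out.max() <= 1.0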
# Predict the digit and the softmax confidence for an input image
def predict(image):
    image = preprocess_image(image)
    prediction = mnist_model.predict(image)
    predicted_label = np.argmax(prediction)
    return predicted_label, np.max(prediction)

# Compute the gradient of the loss with respect to the input pixels
def get_gradients(image, label):
    image = tf.convert_to_tensor(image.reshape((1, 28, 28, 1)), dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(image)
        prediction = mnist_model(image)
        loss = tf.keras.losses.categorical_crossentropy([label], prediction)
    gradients = tape.gradient(loss, image)
    return gradients.numpy().reshape(28, 28)
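
# A small helper sketch (an illustrative addition, not part of the original
# app): normalize the absolute input gradients into a [0, 1] saliency map;
# this is the attribution signal that the masking loop below thresholds.
def saliency_map(image):
    # image: a preprocessed (28, 28) array with values in [0, 1]
    label = np.argmax(mnist_model.predict(image.reshape(1, 28, 28, 1)))
    grads = np.abs(get_gradients(image, to_categorical(label, 10)))
    return grads / (grads.max() + 1e-8)  # guard against an all-zero gradient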
# Progressively zero out the highest-attribution pixels until the prediction flips
def progressively_mask_image(image, steps=100, increment=5):
    image = preprocess_image(image).reshape(28, 28)
    original_prediction = mnist_model.predict(image.reshape(1, 28, 28, 1))
    original_label = np.argmax(original_prediction)
    gradients = get_gradients(image, to_categorical(original_label, 10))
    modified_image = np.copy(image)
    predicted_label = original_label
    for i in range(1, steps + 1):
        # Clamp the percentile to [0, 100]; otherwise large steps * increment
        # values push it negative and np.percentile raises a ValueError
        threshold = np.percentile(np.abs(gradients), max(0, 100 - i * increment))
        mask = np.abs(gradients) > threshold
        modified_image[mask] = 0
        modified_prediction = mnist_model.predict(modified_image.reshape(1, 28, 28, 1))
        predicted_label = np.argmax(modified_prediction)
        if predicted_label != original_label:
            break
    return cv2.resize(modified_image, DIMS), original_label, predicted_label
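
# Illustrative usage sketch (an assumption, not wired into the interface):
# run the masking loop on one held-out test digit and report whether the
# prediction survived the attribution drop.
def demo_masking(index=0, steps=100, increment=1):
    _, original_label, predicted_label = progressively_mask_image(x_test[index], steps, increment)
    print(f"Label before masking: {original_label}, after masking: {predicted_label}")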
# Gradio interface functions
def gradio_predict(image):
    predicted_label, confidence = predict(image)
    return f"Predicted Label: {predicted_label}, Confidence: {confidence:.4f}"

def gradio_mask(image, steps, increment=1):
    modified_image, original_label, predicted_label = progressively_mask_image(image, steps, increment)
    return modified_image, f"Original Label: {original_label}, New Label: {predicted_label}"

# FGSM attack: step each pixel by epsilon in the direction that increases the loss
def fgsm_attack(image, epsilon, data_grad):
    sign_of_grad = tf.sign(data_grad)
    perturbed_image = image + epsilon * sign_of_grad
    perturbed_image = tf.clip_by_value(perturbed_image, 0, 1)  # keep pixels in [0, 1]
    return perturbed_image
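
# A tiny worked sketch of the FGSM update (illustrative, with a hand-made
# gradient rather than one taken from the model): each pixel moves by exactly
# +/- epsilon according to the sign of its gradient, then the result is
# clipped back into the valid pixel range.
def _fgsm_demo():
    x = tf.constant([[0.2, 0.9]])
    grad = tf.constant([[-3.0, 5.0]])  # tf.sign reduces this to [-1.0, 1.0]
    adv = fgsm_attack(x, 0.2, grad)
    # x + 0.2 * sign(grad) = [0.0, 1.1], clipped to [0.0, 1.0]
    return adv.numpy()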
# Compute the input gradient used to build the FGSM perturbation
def create_adversarial_pattern(input_image, input_label):
    with tf.GradientTape() as tape:
        tape.watch(input_image)
        prediction = adv_model(input_image)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(input_label, prediction)
    gradient = tape.gradient(loss, input_image)
    return gradient

# Generate adversarial examples
epsilon = 0.20  # Tweak epsilon to change the intensity of the perturbation
adversarial_examples = []
for i in range(10):
    img = preprocess_image(mnist_examples[i][0])
    img = tf.convert_to_tensor(img, dtype=tf.float32)
    label = tf.reshape(tf.convert_to_tensor([i]), [1, 1])
    perturbations = create_adversarial_pattern(img, label)
    adv_x = fgsm_attack(img, epsilon, perturbations)
    adversarial_examples.append([cv2.resize(adv_x.numpy().squeeze(), DIMS)])

# Resize the MNIST examples to the 100x100 display size
mnist_examples = [[cv2.resize(example[0], DIMS)] for example in mnist_examples]
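
# Hedged evaluation sketch (an addition, not part of the original app): count
# how many of the ten stored digits the epsilon-0.2 FGSM images actually flip,
# rebuilding the inputs from the raw test set rather than the resized display
# copies above.
def _adv_flip_count(eps=0.20):
    flips = 0
    for digit in range(10):
        raw = x_test[np.where(y_test == digit)[0][0]]
        img = tf.convert_to_tensor(preprocess_image(raw), dtype=tf.float32)
        label = tf.reshape(tf.convert_to_tensor([digit]), [1, 1])
        adv = fgsm_attack(img, eps, create_adversarial_pattern(img, label))
        if np.argmax(adv_model.predict(adv.numpy())) != digit:
            flips += 1
    return flips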
class GradioInterface:
    def __init__(self):
        self.preloaded_examples = mnist_examples
        self.adversarial_examples = adversarial_examples

    def create_interface(self):
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }

            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }
            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }
            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }
            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }

            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }
            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }
            .publication-link:hover {
                background-color: #555;
            }
            .publication-link i {
                margin-right: 8px;
            }

            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }

            /* Image Styles */
            .image-preview img {
                max-width: 512px;
                max-height: 512px;
                margin: 0 auto;
                border-radius: 4px;
                display: block;
                object-fit: contain;
            }

            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }

            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }
            .gr-button:hover {
                background-color: #5a5a5a;
            }
            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }
            .gr-form {
                background-color: transparent;
            }
            .gr-panel {
                border: none;
                background-color: transparent;
            }

            /* Override any conflicting styles from Bulma */
            .button.is-normal.is-rounded.is-dark {
                color: #fff !important;
                text-decoration: none !important;
            }
        </style>
        """

        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Attribution Based Confidence Metric for Neural Networks</h1>
            <h2 class="app-subtitle">Steven Fernandes, Ph.D.</h2>
        </div>
        """

        # Force the Gradio dark theme via the __theme query parameter
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """
        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)

            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    steps_input = gr.Slider(minimum=1, maximum=100, label="Attributions Drop Percentage", step=1, value=5)
                    examples = gr.Examples(
                        examples=self.preloaded_examples,
                        inputs=input_image,
                        label="MNIST Examples"
                    )
                    adv_examples = gr.Examples(
                        examples=self.adversarial_examples,
                        inputs=input_image,
                        label="Adversarial Examples"
                    )
                with gr.Column():
                    result = gr.Image(label="Result", elem_classes="image-preview")
                    prediction_details = gr.Textbox(label="Prediction Details")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            run_button.click(
                fn=gradio_mask,
                inputs=[input_image, steps_input],
                outputs=[result, prediction_details],
            )

        return demo
def main():
    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(debug=True)

if __name__ == "__main__":
    main()