Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
+
from tensorflow.keras.models import load_model
|
7 |
+
from tensorflow.keras.utils import to_categorical
|
8 |
+
from sklearn.metrics import confusion_matrix
|
9 |
+
from tensorflow.keras.datasets import mnist
|
10 |
+
import cv2
|
11 |
+
|
12 |
+
# I/O image dimensions for display
|
13 |
+
DIMS = (100,100)
|
14 |
+
# Load the trained model
|
15 |
+
model = load_model('model.h5')
|
16 |
+
|
17 |
+
# Load MNIST examples
|
18 |
+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
19 |
+
mnist_examples = [[x_test[i]] for i in range(10)] # Select first 10 examples and format as nested list
|
20 |
+
# resize the examples 100 by 100
|
21 |
+
mnist_examples = [[cv2.resize(x_test[i], DIMS)] for i in range(10)]
|
22 |
+
|
23 |
+
# Function to preprocess the image
|
24 |
+
def preprocess_image(image):
|
25 |
+
image = cv2.resize(image, (28, 28)) # Resize to 28x28 for model input
|
26 |
+
image = np.array(image) / 255.0
|
27 |
+
image = image.reshape(1, 28, 28, 1)
|
28 |
+
return image
|
29 |
+
|
30 |
+
# Function to make predictions
|
31 |
+
def predict(image):
|
32 |
+
image = preprocess_image(image)
|
33 |
+
prediction = model.predict(image)
|
34 |
+
predicted_label = np.argmax(prediction)
|
35 |
+
return predicted_label, np.max(prediction)
|
36 |
+
|
37 |
+
# Function to compute gradients
|
38 |
+
def get_gradients(image, label):
|
39 |
+
image = tf.convert_to_tensor(image.reshape((1, 28, 28, 1)), dtype=tf.float32)
|
40 |
+
with tf.GradientTape() as tape:
|
41 |
+
tape.watch(image)
|
42 |
+
prediction = model(image)
|
43 |
+
loss = tf.keras.losses.categorical_crossentropy([label], prediction)
|
44 |
+
gradients = tape.gradient(loss, image)
|
45 |
+
return gradients.numpy().reshape(28, 28)
|
46 |
+
|
47 |
+
# Function to progressively mask image and observe changes
|
48 |
+
def progressively_mask_image(image, steps=100, increment=5):
|
49 |
+
image = preprocess_image(image).reshape(28, 28)
|
50 |
+
label = np.argmax(model.predict(image.reshape(1, 28, 28, 1)))
|
51 |
+
gradients = get_gradients(image, to_categorical(label, 10))
|
52 |
+
|
53 |
+
modified_image = np.copy(image)
|
54 |
+
original_prediction = model.predict(image.reshape(1, 28, 28, 1))
|
55 |
+
original_label = np.argmax(original_prediction)
|
56 |
+
|
57 |
+
for i in range(1, steps + 1):
|
58 |
+
threshold = np.percentile(np.abs(gradients), 100 - i * increment)
|
59 |
+
mask = np.abs(gradients) > threshold
|
60 |
+
modified_image[mask] = 0
|
61 |
+
modified_prediction = model.predict(modified_image.reshape(1, 28, 28, 1))
|
62 |
+
predicted_label = np.argmax(modified_prediction)
|
63 |
+
if predicted_label != original_label:
|
64 |
+
break
|
65 |
+
|
66 |
+
return cv2.resize(modified_image, DIMS), original_label, predicted_label
|
67 |
+
|
68 |
+
# Gradio interface functions
|
69 |
+
def gradio_predict(image):
|
70 |
+
predicted_label, confidence = predict(image)
|
71 |
+
return f"Predicted Label: {predicted_label}, Confidence: {confidence:.4f}"
|
72 |
+
|
73 |
+
def gradio_mask(image, steps, increment):
|
74 |
+
modified_image, original_label, predicted_label = progressively_mask_image(image, steps, increment)
|
75 |
+
return modified_image, f"Original Label: {original_label}, New Label: {predicted_label}"
|
76 |
+
|
77 |
+
# Gradio interface
|
78 |
+
image_input = gr.Image(image_mode='L', label="Input Image")
|
79 |
+
steps_input = gr.Slider(minimum=1, maximum=100, label="Steps", step=1, value=100)
|
80 |
+
increment_input = gr.Slider(minimum=1, maximum=20, label="Increment", step=1, value=5)
|
81 |
+
|
82 |
+
|
83 |
+
gr.Interface(
|
84 |
+
fn=gradio_mask,
|
85 |
+
inputs=[image_input, steps_input, increment_input],
|
86 |
+
outputs=[
|
87 |
+
gr.Image(image_mode='L', label="Ouput Image"),
|
88 |
+
gr.Textbox(label="Prediction Details")
|
89 |
+
],
|
90 |
+
title="Progressive Masking",
|
91 |
+
description="Upload an image of a digit and observe how masking affects the model's prediction.",
|
92 |
+
examples=mnist_examples,
|
93 |
+
allow_flagging="never"
|
94 |
+
).launch()
|