AI-RESEARCHER-2024 commited on
Commit
601068b
·
verified ·
1 Parent(s): 75e93f3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from tensorflow.keras.models import load_model
7
+ from tensorflow.keras.utils import to_categorical
8
+ from sklearn.metrics import confusion_matrix
9
+ from tensorflow.keras.datasets import mnist
10
+ import cv2
11
+
12
+ # I/O image dimensions for display
13
+ DIMS = (100,100)
14
+ # Load the trained model
15
+ model = load_model('model.h5')
16
+
17
+ # Load MNIST examples
18
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
19
+ mnist_examples = [[x_test[i]] for i in range(10)] # Select first 10 examples and format as nested list
20
+ # resize the examples 100 by 100
21
+ mnist_examples = [[cv2.resize(x_test[i], DIMS)] for i in range(10)]
22
+
23
+ # Function to preprocess the image
24
+ def preprocess_image(image):
25
+ image = cv2.resize(image, (28, 28)) # Resize to 28x28 for model input
26
+ image = np.array(image) / 255.0
27
+ image = image.reshape(1, 28, 28, 1)
28
+ return image
29
+
30
+ # Function to make predictions
31
+ def predict(image):
32
+ image = preprocess_image(image)
33
+ prediction = model.predict(image)
34
+ predicted_label = np.argmax(prediction)
35
+ return predicted_label, np.max(prediction)
36
+
37
+ # Function to compute gradients
38
+ def get_gradients(image, label):
39
+ image = tf.convert_to_tensor(image.reshape((1, 28, 28, 1)), dtype=tf.float32)
40
+ with tf.GradientTape() as tape:
41
+ tape.watch(image)
42
+ prediction = model(image)
43
+ loss = tf.keras.losses.categorical_crossentropy([label], prediction)
44
+ gradients = tape.gradient(loss, image)
45
+ return gradients.numpy().reshape(28, 28)
46
+
47
+ # Function to progressively mask image and observe changes
48
+ def progressively_mask_image(image, steps=100, increment=5):
49
+ image = preprocess_image(image).reshape(28, 28)
50
+ label = np.argmax(model.predict(image.reshape(1, 28, 28, 1)))
51
+ gradients = get_gradients(image, to_categorical(label, 10))
52
+
53
+ modified_image = np.copy(image)
54
+ original_prediction = model.predict(image.reshape(1, 28, 28, 1))
55
+ original_label = np.argmax(original_prediction)
56
+
57
+ for i in range(1, steps + 1):
58
+ threshold = np.percentile(np.abs(gradients), 100 - i * increment)
59
+ mask = np.abs(gradients) > threshold
60
+ modified_image[mask] = 0
61
+ modified_prediction = model.predict(modified_image.reshape(1, 28, 28, 1))
62
+ predicted_label = np.argmax(modified_prediction)
63
+ if predicted_label != original_label:
64
+ break
65
+
66
+ return cv2.resize(modified_image, DIMS), original_label, predicted_label
67
+
68
+ # Gradio interface functions
69
+ def gradio_predict(image):
70
+ predicted_label, confidence = predict(image)
71
+ return f"Predicted Label: {predicted_label}, Confidence: {confidence:.4f}"
72
+
73
+ def gradio_mask(image, steps, increment):
74
+ modified_image, original_label, predicted_label = progressively_mask_image(image, steps, increment)
75
+ return modified_image, f"Original Label: {original_label}, New Label: {predicted_label}"
76
+
77
+ # Gradio interface
78
+ image_input = gr.Image(image_mode='L', label="Input Image")
79
+ steps_input = gr.Slider(minimum=1, maximum=100, label="Steps", step=1, value=100)
80
+ increment_input = gr.Slider(minimum=1, maximum=20, label="Increment", step=1, value=5)
81
+
82
+
83
+ gr.Interface(
84
+ fn=gradio_mask,
85
+ inputs=[image_input, steps_input, increment_input],
86
+ outputs=[
87
+ gr.Image(image_mode='L', label="Ouput Image"),
88
+ gr.Textbox(label="Prediction Details")
89
+ ],
90
+ title="Progressive Masking",
91
+ description="Upload an image of a digit and observe how masking affects the model's prediction.",
92
+ examples=mnist_examples,
93
+ allow_flagging="never"
94
+ ).launch()