AI-RESEARCHER-2024 commited on
Commit
f9dbf06
·
verified ·
1 Parent(s): 6d8e979

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -28
app.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  import tensorflow as tf
4
  from tensorflow.keras.models import load_model
5
  from tensorflow.keras.utils import to_categorical
6
- from sklearn.metrics import confusion_matrix
7
  from tensorflow.keras.datasets import mnist
8
  import cv2
9
 
@@ -72,35 +71,176 @@ def gradio_mask(image, steps, increment):
72
  modified_image, original_label, predicted_label = progressively_mask_image(image, steps, increment)
73
  return modified_image, f"Original Label: {original_label}, New Label: {predicted_label}"
74
 
75
- def create_interface():
76
- card_html = """
77
- <div style="background-color: #f8f9fa; padding: 20px; border-radius: 10px; text-align: center; margin-bottom: 20px;">
78
- <h1 style="font-family: Arial, sans-serif; color: #333;">Attribution Based Confidence Metric for Neural Networks</h1>
79
- <h2 style="font-family: Arial, sans-serif; color: #555;">Steven Fernandes, Ph.D.</h2>
80
- </div>
81
- """
82
-
83
- image_input = gr.Image(image_mode='L', label="Input Image")
84
- steps_input = gr.Slider(minimum=1, maximum=100, label="Steps", step=1, value=100)
85
- increment_input = gr.Slider(minimum=1, maximum=20, label="Increment", step=1, value=5)
86
-
87
- with gr.Blocks() as demo:
88
- gr.HTML(card_html)
89
- gr.Interface(
90
- fn=gradio_mask,
91
- inputs=[image_input, steps_input, increment_input],
92
- outputs=[
93
- gr.Image(image_mode='L', label="Output Image"),
94
- gr.Textbox(label="Prediction Details")
95
- ],
96
- title="Progressive Masking",
97
- description="Upload an image of a digit and observe how masking affects the model's prediction.",
98
- examples=mnist_examples,
99
- allow_flagging="never"
100
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  def main():
103
- create_interface()
 
 
104
 
105
  if __name__ == "__main__":
106
  main()
 
3
  import tensorflow as tf
4
  from tensorflow.keras.models import load_model
5
  from tensorflow.keras.utils import to_categorical
 
6
  from tensorflow.keras.datasets import mnist
7
  import cv2
8
 
 
71
  modified_image, original_label, predicted_label = progressively_mask_image(image, steps, increment)
72
  return modified_image, f"Original Label: {original_label}, New Label: {predicted_label}"
73
 
74
+ class GradioInterface:
75
+ def __init__(self):
76
+ self.preloaded_examples = self.preload_examples()
77
+
78
+ def preload_examples(self):
79
+ preloaded = {}
80
+ for model_name, example_dir in Config.EXAMPLES.items():
81
+ examples = [os.path.join(example_dir, img) for img in os.listdir(example_dir)]
82
+ preloaded[model_name] = examples
83
+ return preloaded
84
+
85
+ def create_interface(self):
86
+ app_styles = """
87
+ <style>
88
+ /* Global Styles */
89
+ body, #root {
90
+ font-family: Helvetica, Arial, sans-serif;
91
+ background-color: #1a1a1a;
92
+ color: #fafafa;
93
+ }
94
+ /* Header Styles */
95
+ .app-header {
96
+ background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
97
+ padding: 24px;
98
+ border-radius: 8px;
99
+ margin-bottom: 24px;
100
+ text-align: center;
101
+ }
102
+ .app-title {
103
+ font-size: 48px;
104
+ margin: 0;
105
+ color: #fafafa;
106
+ }
107
+ .app-subtitle {
108
+ font-size: 24px;
109
+ margin: 8px 0 16px;
110
+ color: #fafafa;
111
+ }
112
+ .app-description {
113
+ font-size: 16px;
114
+ line-height: 1.6;
115
+ opacity: 0.8;
116
+ margin-bottom: 24px;
117
+ }
118
+ /* Button Styles */
119
+ .publication-links {
120
+ display: flex;
121
+ justify-content: center;
122
+ flex-wrap: wrap;
123
+ gap: 8px;
124
+ margin-bottom: 16px;
125
+ }
126
+ .publication-link {
127
+ display: inline-flex;
128
+ align-items: center;
129
+ padding: 8px 16px;
130
+ background-color: #333;
131
+ color: #fff !important;
132
+ text-decoration: none !important;
133
+ border-radius: 20px;
134
+ font-size: 14px;
135
+ transition: background-color 0.3s;
136
+ }
137
+ .publication-link:hover {
138
+ background-color: #555;
139
+ }
140
+ .publication-link i {
141
+ margin-right: 8px;
142
+ }
143
+ /* Content Styles */
144
+ .content-container {
145
+ background-color: #2a2a2a;
146
+ border-radius: 8px;
147
+ padding: 24px;
148
+ margin-bottom: 24px;
149
+ }
150
+ /* Image Styles */
151
+ .image-preview img {
152
+ max-width: 512px;
153
+ max-height: 512px;
154
+ margin: 0 auto;
155
+ border-radius: 4px;
156
+ display: block;
157
+ object-fit: contain;
158
+ }
159
+ /* Control Styles */
160
+ .control-panel {
161
+ background-color: #333;
162
+ padding: 16px;
163
+ border-radius: 8px;
164
+ margin-top: 16px;
165
+ }
166
+ /* Gradio Component Overrides */
167
+ .gr-button {
168
+ background-color: #4a4a4a;
169
+ color: #fff;
170
+ border: none;
171
+ border-radius: 4px;
172
+ padding: 8px 16px;
173
+ cursor: pointer;
174
+ transition: background-color 0.3s;
175
+ }
176
+ .gr-button:hover {
177
+ background-color: #5a5a5a;
178
+ }
179
+ .gr-input, .gr-dropdown {
180
+ background-color: #3a3a3a;
181
+ color: #fff;
182
+ border: 1px solid #4a4a4a;
183
+ border-radius: 4px;
184
+ padding: 8px;
185
+ }
186
+ .gr-form {
187
+ background-color: transparent;
188
+ }
189
+ .gr-panel {
190
+ border: none;
191
+ background-color: transparent;
192
+ }
193
+ /* Override any conflicting styles from Bulma */
194
+ .button.is-normal.is-rounded.is-dark {
195
+ color: #fff !important;
196
+ text-decoration: none !important;
197
+ }
198
+ </style>
199
+ """
200
+
201
+ header_html = f"""
202
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
203
+ <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
204
+ {app_styles}
205
+ <div class="app-header">
206
+ <h1 class="app-title">Attribution Based Confidence Metric for Neural Networks</h1>
207
+ <h2 class="app-subtitle">Steven Fernandes, Ph.D.</h2>
208
+ </div>
209
+ """
210
+
211
+ js_func = """
212
+ function refresh() {
213
+ const url = new URL(window.location);
214
+ if (url.searchParams.get('__theme') !== 'dark') {
215
+ url.searchParams.set('__theme', 'dark');
216
+ window.location.href = url.href;
217
+ }
218
+ }
219
+ """
220
+
221
+ with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
222
+ gr.HTML(header_html)
223
+ with gr.Row(elem_classes="content-container"):
224
+ with gr.Column():
225
+ input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
226
+ steps_input = gr.Slider(minimum=1, maximum=100, label="Steps", step=1, value=100)
227
+ increment_input = gr.Slider(minimum=1, maximum=20, label="Increment", step=1, value=5)
228
+ with gr.Column():
229
+ result = gr.Image(label="Result", elem_classes="image-preview")
230
+ run_button = gr.Button("Run", elem_classes="gr-button")
231
+
232
+ run_button.click(
233
+ fn=gradio_mask,
234
+ inputs=[input_image, steps_input, increment_input],
235
+ outputs=[result, gr.Textbox(label="Prediction Details")],
236
+ )
237
+
238
+ return demo
239
 
240
  def main():
241
+ interface = GradioInterface()
242
+ demo = interface.create_interface()
243
+ demo.launch(debug=True)
244
 
245
  if __name__ == "__main__":
246
  main()