cmckinle commited on
Commit
12fbe49
·
verified ·
1 Parent(s): 58af1bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -4
app.py CHANGED
@@ -10,6 +10,8 @@ from PIL import Image
10
  import tempfile
11
  import numpy as np
12
  import urllib.request
 
 
13
 
14
  MODEL_NAME = "cmckinle/sdxl-flux-detector"
15
  LABELS = ["AI", "Real"]
@@ -45,6 +47,7 @@ def process_zip(zip_file):
45
  z.extractall(temp_dir)
46
 
47
  labels, preds, images = [], [], []
 
48
  detector = AIDetector()
49
 
50
  for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
@@ -62,11 +65,20 @@ def process_zip(zip_file):
62
  preds.append(pred_label)
63
  labels.append(ground_truth_label)
64
  images.append(img_name)
 
 
 
 
 
 
 
 
 
65
  except Exception as e:
66
  print(f"Error processing image {img_name}: {e}")
67
 
68
  shutil.rmtree(temp_dir)
69
- return evaluate_model(labels, preds)
70
 
71
  def format_classification_report(labels, preds):
72
  # Convert the report string to a dictionary
@@ -166,7 +178,7 @@ def format_classification_report(labels, preds):
166
 
167
  return html
168
 
169
- def evaluate_model(labels, preds):
170
  cm = confusion_matrix(labels, preds)
171
  accuracy = accuracy_score(labels, preds)
172
  roc_score = roc_auc_score(labels, preds)
@@ -190,7 +202,49 @@ def evaluate_model(labels, preds):
190
 
191
  plt.tight_layout()
192
 
193
- return accuracy, roc_score, report_html, fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  def load_url(url):
196
  try:
@@ -234,6 +288,7 @@ def create_gradio_interface():
234
  output_roc = gr.Label(label="ROC Score")
235
  output_report = gr.HTML(label="Classification Report")
236
  output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
 
237
 
238
  load_btn.click(load_url, in_url, [inp, message])
239
  btn.click(
@@ -245,7 +300,7 @@ def create_gradio_interface():
245
  batch_btn.click(
246
  process_zip,
247
  zip_file,
248
- [output_acc, output_roc, output_report, output_plots]
249
  )
250
 
251
  return app
 
10
  import tempfile
11
  import numpy as np
12
  import urllib.request
13
+ import base64
14
+ from io import BytesIO
15
 
16
  MODEL_NAME = "cmckinle/sdxl-flux-detector"
17
  LABELS = ["AI", "Real"]
 
47
  z.extractall(temp_dir)
48
 
49
  labels, preds, images = [], [], []
50
+ false_positives, false_negatives = [], []
51
  detector = AIDetector()
52
 
53
  for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
 
65
  preds.append(pred_label)
66
  labels.append(ground_truth_label)
67
  images.append(img_name)
68
+
69
+ # Collect false positives and false negatives with image data
70
+ if pred_label != ground_truth_label:
71
+ img_data = base64.b64encode(open(img_path, "rb").read()).decode()
72
+ if pred_label == 1 and ground_truth_label == 0:
73
+ false_positives.append((img_name, img_data))
74
+ elif pred_label == 0 and ground_truth_label == 1:
75
+ false_negatives.append((img_name, img_data))
76
+
77
  except Exception as e:
78
  print(f"Error processing image {img_name}: {e}")
79
 
80
  shutil.rmtree(temp_dir)
81
+ return evaluate_model(labels, preds, false_positives, false_negatives)
82
 
83
  def format_classification_report(labels, preds):
84
  # Convert the report string to a dictionary
 
178
 
179
  return html
180
 
181
+ def evaluate_model(labels, preds, false_positives, false_negatives):
182
  cm = confusion_matrix(labels, preds)
183
  accuracy = accuracy_score(labels, preds)
184
  roc_score = roc_auc_score(labels, preds)
 
202
 
203
  plt.tight_layout()
204
 
205
+ # Create HTML for false positives and negatives with images
206
+ fp_fn_html = """
207
+ <style>
208
+ .image-grid {
209
+ display: flex;
210
+ flex-wrap: wrap;
211
+ gap: 10px;
212
+ }
213
+ .image-item {
214
+ display: flex;
215
+ flex-direction: column;
216
+ align-items: center;
217
+ }
218
+ .image-item img {
219
+ max-width: 200px;
220
+ max-height: 200px;
221
+ }
222
+ </style>
223
+ """
224
+
225
+ fp_fn_html += "<h3>False Positives (AI images classified as Real):</h3>"
226
+ fp_fn_html += '<div class="image-grid">'
227
+ for img_name, img_data in false_positives:
228
+ fp_fn_html += f'''
229
+ <div class="image-item">
230
+ <img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
231
+ <p>{img_name}</p>
232
+ </div>
233
+ '''
234
+ fp_fn_html += '</div>'
235
+
236
+ fp_fn_html += "<h3>False Negatives (Real images classified as AI):</h3>"
237
+ fp_fn_html += '<div class="image-grid">'
238
+ for img_name, img_data in false_negatives:
239
+ fp_fn_html += f'''
240
+ <div class="image-item">
241
+ <img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
242
+ <p>{img_name}</p>
243
+ </div>
244
+ '''
245
+ fp_fn_html += '</div>'
246
+
247
+ return accuracy, roc_score, report_html, fig, fp_fn_html
248
 
249
  def load_url(url):
250
  try:
 
288
  output_roc = gr.Label(label="ROC Score")
289
  output_report = gr.HTML(label="Classification Report")
290
  output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
291
+ output_fp_fn = gr.HTML(label="False Positives and Negatives")
292
 
293
  load_btn.click(load_url, in_url, [inp, message])
294
  btn.click(
 
300
  batch_btn.click(
301
  process_zip,
302
  zip_file,
303
+ [output_acc, output_roc, output_report, output_plots, output_fp_fn]
304
  )
305
 
306
  return app