cmckinle committed on
Commit
1d7c27a
·
verified ·
1 Parent(s): 542febe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -109
app.py CHANGED
@@ -53,46 +53,7 @@ def process_zip(zip_file):
53
 
54
  z.extractall(temp_dir)
55
 
56
- labels, preds, images = [], [], []
57
- false_positives, false_negatives = [], []
58
- detector = AIDetector()
59
-
60
- total_images = sum(len(files) for _, _, files in os.walk(temp_dir))
61
- processed_images = 0
62
-
63
- for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
64
- folder_path = os.path.join(temp_dir, folder_name)
65
- if not os.path.exists(folder_path):
66
- raise ValueError(f"Folder not found: {folder_path}")
67
-
68
- for img_name in os.listdir(folder_path):
69
- img_path = os.path.join(folder_path, img_name)
70
- try:
71
- with Image.open(img_path).convert("RGB") as img:
72
- _, prediction = detector.predict(img)
73
-
74
- pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
75
-
76
- preds.append(pred_label)
77
- labels.append(ground_truth_label)
78
- images.append(img_name)
79
-
80
- # Collect false positives and false negatives with image data
81
- if pred_label != ground_truth_label:
82
- with open(img_path, "rb") as img_file:
83
- img_data = base64.b64encode(img_file.read()).decode()
84
- if pred_label == 1 and ground_truth_label == 0:
85
- false_positives.append((img_name, img_data))
86
- elif pred_label == 0 and ground_truth_label == 1:
87
- false_negatives.append((img_name, img_data))
88
-
89
- except Exception as e:
90
- print(f"Error processing image {img_name}: {e}")
91
-
92
- processed_images += 1
93
- gr.Progress(processed_images / total_images)
94
-
95
- return evaluate_model(labels, preds, false_positives, false_negatives)
96
 
97
  except Exception as e:
98
  raise gr.Error(f"Error processing zip file: {str(e)}")
@@ -100,11 +61,100 @@ def process_zip(zip_file):
100
  finally:
101
  shutil.rmtree(temp_dir)
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def format_classification_report(labels, preds):
104
- # Convert the report string to a dictionary
105
  report_dict = classification_report(labels, preds, output_dict=True)
106
 
107
- # Create an HTML table with updated CSS
108
  html = """
109
  <style>
110
  .report-table {
@@ -170,7 +220,6 @@ def format_classification_report(labels, preds):
170
  </tr>
171
  """
172
 
173
- # Add rows for each class
174
  for class_name in ['0', '1']:
175
  html += f"""
176
  <tr>
@@ -182,7 +231,6 @@ def format_classification_report(labels, preds):
182
  </tr>
183
  """
184
 
185
- # Add summary rows
186
  html += f"""
187
  <tr>
188
  <td>Accuracy</td>
@@ -207,33 +255,9 @@ def format_classification_report(labels, preds):
207
  """
208
 
209
  return html
210
-
211
- def evaluate_model(labels, preds, false_positives, false_negatives):
212
- cm = confusion_matrix(labels, preds)
213
- accuracy = accuracy_score(labels, preds)
214
- roc_score = roc_auc_score(labels, preds)
215
- report_html = format_classification_report(labels, preds)
216
- fpr, tpr, _ = roc_curve(labels, preds)
217
- roc_auc = auc(fpr, tpr)
218
-
219
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
220
-
221
- ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
222
- ax1.set_title("Confusion Matrix")
223
-
224
- ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
225
- ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
226
- ax2.set_xlim([0.0, 1.0])
227
- ax2.set_ylim([0.0, 1.05])
228
- ax2.set_xlabel('False Positive Rate')
229
- ax2.set_ylabel('True Positive Rate')
230
- ax2.set_title('ROC Curve')
231
- ax2.legend(loc="lower right")
232
-
233
- plt.tight_layout()
234
 
235
- # Create HTML for false positives and negatives with images
236
- fp_fn_html = """
237
  <style>
238
  .image-grid {
239
  display: flex;
@@ -252,29 +276,29 @@ def evaluate_model(labels, preds, false_positives, false_negatives):
252
  </style>
253
  """
254
 
255
- fp_fn_html += "<h3>False Positives (AI images classified as Real):</h3>"
256
- fp_fn_html += '<div class="image-grid">'
257
  for img_name, img_data in false_positives:
258
- fp_fn_html += f'''
259
  <div class="image-item">
260
  <img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
261
  <p>{img_name}</p>
262
  </div>
263
  '''
264
- fp_fn_html += '</div>'
265
 
266
- fp_fn_html += "<h3>False Negatives (Real images classified as AI):</h3>"
267
- fp_fn_html += '<div class="image-grid">'
268
  for img_name, img_data in false_negatives:
269
- fp_fn_html += f'''
270
  <div class="image-item">
271
  <img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
272
  <p>{img_name}</p>
273
  </div>
274
  '''
275
- fp_fn_html += '</div>'
276
 
277
- return accuracy, roc_score, report_html, fig, fp_fn_html
278
 
279
  def load_url(url):
280
  try:
@@ -309,13 +333,28 @@ def create_gradio_interface():
309
  output_label = gr.Label(label="Output")
310
 
311
  with gr.Tab("Batch Image Processing"):
312
- zip_file = gr.File(
313
- label="Upload Zip (must contain 'real' and 'ai' folders)",
314
- file_types=[".zip"],
315
- file_count="single",
316
- max_file_size=1024 # 1024 MB (1 GB)
317
- )
318
- batch_btn = gr.Button("Process Batch", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  with gr.Group():
321
  gr.Markdown(f"### Results for {MODEL_NAME}")
@@ -332,28 +371,7 @@ def create_gradio_interface():
332
  [output_html, output_label]
333
  )
334
 
335
- def enable_batch_btn(file):
336
  return gr.Button.update(interactive=file is not None)
337
 
338
- zip_file.upload(
339
- enable_batch_btn,
340
- zip_file,
341
- batch_btn
342
- )
343
-
344
- batch_btn.click(
345
- process_zip,
346
- zip_file,
347
- [output_acc, output_roc, output_report, output_plots, output_fp_fn],
348
- api_name="batch_process"
349
- )
350
-
351
- return app
352
-
353
- if __name__ == "__main__":
354
- app = create_gradio_interface()
355
- app.launch(
356
- show_api=False,
357
- max_threads=24,
358
- show_error=True
359
- )
 
53
 
54
  z.extractall(temp_dir)
55
 
56
+ return evaluate_model(temp_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  except Exception as e:
59
  raise gr.Error(f"Error processing zip file: {str(e)}")
 
61
  finally:
62
  shutil.rmtree(temp_dir)
63
 
64
def _copy_uploads(files, dest_dir):
    """Copy uploaded files into ``dest_dir``, keeping same-named uploads apart."""
    os.makedirs(dest_dir, exist_ok=True)
    used_names = set()
    # An upload box left empty is assumed to arrive as None; treat it as "no files".
    for upload in files or []:
        # NOTE(review): presumably the upload is either a tempfile-like object
        # exposing ``.name`` or a plain path string, depending on the Gradio
        # version / ``type`` setting -- confirm against the installed version.
        src_path = getattr(upload, "name", upload)
        stem, ext = os.path.splitext(os.path.basename(src_path))
        target_name = f"{stem}{ext}"
        suffix = 1
        # Two uploads may share a basename; add a numeric suffix instead of
        # silently overwriting the earlier file.
        while target_name in used_names:
            target_name = f"{stem}_{suffix}{ext}"
            suffix += 1
        used_names.add(target_name)
        shutil.copy(src_path, os.path.join(dest_dir, target_name))


def process_files(ai_files, real_files):
    """Evaluate the detector on individually uploaded AI and real images.

    The uploads are staged into a temporary directory with ``ai`` and ``real``
    sub-folders (the layout ``evaluate_model`` walks), evaluated, and the
    temporary directory is removed again whether or not evaluation succeeds.

    Args:
        ai_files: Uploaded files to be labelled as AI-generated; may be None.
        real_files: Uploaded files to be labelled as real; may be None.

    Returns:
        Whatever ``evaluate_model`` returns for the staged directory.

    Raises:
        gr.Error: If staging or evaluation fails for any reason; the original
            exception is chained as the cause.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        _copy_uploads(ai_files, os.path.join(temp_dir, 'ai'))
        _copy_uploads(real_files, os.path.join(temp_dir, 'real'))
        return evaluate_model(temp_dir)
    except Exception as e:
        # Chain the cause so the server log keeps the original traceback.
        raise gr.Error(f"Error processing individual files: {str(e)}") from e
    finally:
        shutil.rmtree(temp_dir)
84
+
85
def evaluate_model(temp_dir, progress=None):
    """Run the detector over ``temp_dir/real`` and ``temp_dir/ai`` and score it.

    Images under ``real`` carry ground-truth label 1 and images under ``ai``
    carry label 0. An image is predicted as AI (0) when the detector's "AI"
    score is strictly greater than its "Real" score, otherwise as real (1).
    Images that fail to load or to be scored are reported on stdout and left
    out of the metrics.

    Args:
        temp_dir: Directory that contains the ``real`` and ``ai`` folders.
        progress: Optional callable receiving the fraction of images handled
            so far (0-1]. NOTE(review): intended for a ``gr.Progress``
            instance injected into the event handler -- confirm at the caller.

    Returns:
        Whatever ``calculate_metrics`` returns for the collected labels,
        predictions, false positives and false negatives.

    Raises:
        ValueError: If either class folder is missing, or if no image could
            be scored at all.
    """
    detector = AIDetector()
    labels, preds = [], []
    false_positives, false_negatives = [], []

    # Validate both folders before any inference so a missing 'ai' folder is
    # reported immediately rather than after every 'real' image has been scored.
    class_folders = []
    for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
        folder_path = os.path.join(temp_dir, folder_name)
        if not os.path.exists(folder_path):
            raise ValueError(f"Folder not found: {folder_path}")
        # Sorted so the processing order (and the error galleries) is stable.
        class_folders.append(
            (folder_path, ground_truth_label, sorted(os.listdir(folder_path)))
        )

    total_images = sum(len(names) for _, _, names in class_folders)
    processed_images = 0

    for folder_path, ground_truth_label, img_names in class_folders:
        for img_name in img_names:
            img_path = os.path.join(folder_path, img_name)
            try:
                # Close the opened file as well as the converted copy;
                # ``Image.open(...).convert(...)`` alone only closes the copy.
                with Image.open(img_path) as raw_img, raw_img.convert("RGB") as img:
                    _, prediction = detector.predict(img)

                pred_label = 0 if prediction["AI"] > prediction["Real"] else 1

                preds.append(pred_label)
                labels.append(ground_truth_label)

                # Keep the raw bytes of misclassified images for the gallery.
                if pred_label != ground_truth_label:
                    with open(img_path, "rb") as img_file:
                        img_data = base64.b64encode(img_file.read()).decode()
                    if pred_label == 1:
                        # AI image classified as real.
                        false_positives.append((img_name, img_data))
                    else:
                        # Real image classified as AI.
                        false_negatives.append((img_name, img_data))

            except Exception as e:
                # Best effort: one unreadable file must not abort the batch.
                print(f"Error processing image {img_name}: {e}")

            processed_images += 1
            if progress is not None:
                progress(processed_images / total_images)

    if not labels:
        raise ValueError("No images could be processed")

    return calculate_metrics(labels, preds, false_positives, false_negatives)
126
+
127
def calculate_metrics(labels, preds, false_positives, false_negatives):
    """Compute evaluation metrics and build the result plots and HTML.

    Args:
        labels: Ground-truth class labels, one per evaluated image.
        preds: Predicted class labels, aligned with ``labels``.
        false_positives: ``(image name, base64 data)`` pairs forwarded to
            ``create_fp_fn_html``.
        false_negatives: ``(image name, base64 data)`` pairs forwarded to
            ``create_fp_fn_html``.

    Returns:
        A 5-tuple ``(accuracy, roc_score, report_html, fig, fp_fn_html)``:
        the accuracy score, the ROC AUC score, the HTML classification
        report, a figure with the confusion matrix (left) and ROC curve
        (right), and the HTML gallery of misclassified images.
    """
    cm = confusion_matrix(labels, preds)
    accuracy = accuracy_score(labels, preds)
    roc_score = roc_auc_score(labels, preds)
    report_html = format_classification_report(labels, preds)
    # NOTE(review): ``preds`` are hard labels here, not probabilities, so the
    # curve is built from label values only -- confirm this is intended.
    fpr, tpr, _ = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
    ax1.set_title("Confusion Matrix")

    ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    ax2.legend(loc="lower right")

    # Lay out this figure explicitly instead of pyplot's implicit current one.
    fig.tight_layout()
    # pyplot keeps every figure it creates registered until it is closed;
    # release it so repeated evaluations do not accumulate figures. The Figure
    # object itself stays usable by the caller after being closed.
    plt.close(fig)

    fp_fn_html = create_fp_fn_html(false_positives, false_negatives)

    return accuracy, roc_score, report_html, fig, fp_fn_html
154
+
155
  def format_classification_report(labels, preds):
 
156
  report_dict = classification_report(labels, preds, output_dict=True)
157
 
 
158
  html = """
159
  <style>
160
  .report-table {
 
220
  </tr>
221
  """
222
 
 
223
  for class_name in ['0', '1']:
224
  html += f"""
225
  <tr>
 
231
  </tr>
232
  """
233
 
 
234
  html += f"""
235
  <tr>
236
  <td>Accuracy</td>
 
255
  """
256
 
257
  return html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
+ def create_fp_fn_html(false_positives, false_negatives):
260
+ html = """
261
  <style>
262
  .image-grid {
263
  display: flex;
 
276
  </style>
277
  """
278
 
279
+ html += "<h3>False Positives (AI images classified as Real):</h3>"
280
+ html += '<div class="image-grid">'
281
  for img_name, img_data in false_positives:
282
+ html += f'''
283
  <div class="image-item">
284
  <img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
285
  <p>{img_name}</p>
286
  </div>
287
  '''
288
+ html += '</div>'
289
 
290
+ html += "<h3>False Negatives (Real images classified as AI):</h3>"
291
+ html += '<div class="image-grid">'
292
  for img_name, img_data in false_negatives:
293
+ html += f'''
294
  <div class="image-item">
295
  <img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
296
  <p>{img_name}</p>
297
  </div>
298
  '''
299
+ html += '</div>'
300
 
301
+ return html
302
 
303
  def load_url(url):
304
  try:
 
333
  output_label = gr.Label(label="Output")
334
 
335
  with gr.Tab("Batch Image Processing"):
336
+ with gr.Accordion("Upload Zip File (max 100MB)", open=False):
337
+ zip_file = gr.File(
338
+ label="Upload Zip (must contain 'real' and 'ai' folders)",
339
+ file_types=[".zip"],
340
+ file_count="single",
341
+ max_file_size=100 # 100 MB limit
342
+ )
343
+ zip_process_btn = gr.Button("Process Zip", interactive=False)
344
+
345
+ with gr.Accordion("Upload Individual Files (for datasets over 100MB)", open=False):
346
+ with gr.Row():
347
+ ai_files = gr.File(
348
+ label="Upload AI Images",
349
+ file_types=["image"],
350
+ file_count="multiple"
351
+ )
352
+ real_files = gr.File(
353
+ label="Upload Real Images",
354
+ file_types=["image"],
355
+ file_count="multiple"
356
+ )
357
+ individual_process_btn = gr.Button("Process Individual Files", interactive=False)
358
 
359
  with gr.Group():
360
  gr.Markdown(f"### Results for {MODEL_NAME}")
 
371
  [output_html, output_label]
372
  )
373
 
374
+ def enable_zip_btn(file):
375
  return gr.Button.update(interactive=file is not None)
376
 
377
+ def enable_individual_btn