cmckinle commited on
Commit
e88ec7e
·
verified ·
1 Parent(s): 3ecd307

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -60
app.py CHANGED
@@ -1,15 +1,22 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
4
- import os
5
- from numpy import exp
6
- import pandas as pd
 
 
 
7
  from PIL import Image
8
- import urllib.request
9
  import uuid
10
- uid = uuid.uuid4()
 
 
 
 
 
11
 
12
- # Reordered models as requested
13
  models = [
14
  "umm-maybe/AI-image-detector",
15
  "Organika/sdxl-detector",
@@ -21,11 +28,78 @@ pipe1 = pipeline("image-classification", f"{models[1]}")
21
  pipe2 = pipeline("image-classification", f"{models[2]}")
22
 
23
  fin_sum = []
 
24
 
 
25
  def softmax(vector):
26
  e = exp(vector - vector.max()) # for numerical stability
27
  return e / e.sum()
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def image_classifier0(image):
30
  labels = ["AI", "Real"]
31
  outputs = pipe0(image)
@@ -70,8 +144,8 @@ def aiornot0(image):
70
  html_out = f"""
71
  <h1>This image is likely: {label}</h1><br><h3>
72
  Probabilities:<br>
73
- Real: {float(px[1][0])}<br>
74
- AI: {float(px[0][0])}"""
75
 
76
  results = {
77
  "Real": float(px[1][0]),
@@ -97,8 +171,8 @@ def aiornot1(image):
97
  html_out = f"""
98
  <h1>This image is likely: {label}</h1><br><h3>
99
  Probabilities:<br>
100
- Real: {float(px[1][0])}<br>
101
- AI: {float(px[0][0])}"""
102
 
103
  results = {
104
  "Real": float(px[1][0]),
@@ -124,8 +198,8 @@ def aiornot2(image):
124
  html_out = f"""
125
  <h1>This image is likely: {label}</h1><br><h3>
126
  Probabilities:<br>
127
- Real: {float(px[1][0])}<br>
128
- AI: {float(px[0][0])}"""
129
 
130
  results = {
131
  "Real": float(px[1][0]),
@@ -149,8 +223,8 @@ def tot_prob():
149
  fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
150
  fin_sub = 1 - fin_out
151
  out = {
152
- "Real": f"{fin_out}",
153
- "AI": f"{fin_sub}"
154
  }
155
  return out
156
  except Exception as e:
@@ -167,50 +241,56 @@ def upd(image):
167
  out = Image.open(f"{rand_im}-vid_tmp_proc.png")
168
  return out
169
 
 
170
  with gr.Blocks() as app:
171
- gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)""")
172
- with gr.Column():
173
- inp = gr.Image(type='pil')
174
- in_url = gr.Textbox(label="Image URL")
175
- with gr.Row():
176
- load_btn = gr.Button("Load URL")
177
- btn = gr.Button("Detect AI")
178
- mes = gr.HTML("""""")
179
-
180
- with gr.Group():
181
- with gr.Row():
182
- fin = gr.Label(label="Final Probability", visible=False)
183
- with gr.Row():
184
- # Updated model names
185
- with gr.Box():
186
- lab0 = gr.HTML(f"""<b>Testing on Original Model: <a href='https://huggingface.co/{models[0]}'>{models[0]}</a></b>""")
187
- nun0 = gr.HTML("""""")
188
- with gr.Box():
189
- lab1 = gr.HTML(f"""<b>Testing on SDXL Fine Tuned Model: <a href='https://huggingface.co/{models[1]}'>{models[1]}</a></b>""")
190
- nun1 = gr.HTML("""""")
191
- with gr.Box():
192
- lab2 = gr.HTML(f"""<b>Testing on SDXL and Flux Fine Tuned Model: <a href='https://huggingface.co/{models[2]}'>{models[2]}</a></b>""")
193
- nun2 = gr.HTML("""""")
194
- with gr.Row():
195
- with gr.Box():
196
- n_out0 = gr.Label(label="Output")
197
- outp0 = gr.HTML("""""")
198
- with gr.Box():
199
- n_out1 = gr.Label(label="Output")
200
- outp1 = gr.HTML("""""")
201
- with gr.Box():
202
- n_out2 = gr.Label(label="Output")
203
- outp2 = gr.HTML("""""")
204
-
205
- btn.click(fin_clear, None, fin, show_progress=False)
206
- load_btn.click(load_url, in_url, [inp, mes])
207
-
208
- btn.click(aiornot0, [inp], [outp0, n_out0]).then(tot_prob, None, fin, show_progress=False)
209
- btn.click(aiornot1, [inp], [outp1, n_out1]).then(tot_prob, None, fin, show_progress=False)
210
- btn.click(aiornot2, [inp], [outp2, n_out2]).then(tot_prob, None, fin, show_progress=False)
211
-
212
- btn.click(image_classifier0, [inp], [n_out0]).then(tot_prob, None, fin, show_progress=False)
213
- btn.click(image_classifier1, [inp], [n_out1]).then(tot_prob, None, fin, show_progress=False)
214
- btn.click(image_classifier2, [inp], [n_out2]).then(tot_prob, None, fin, show_progress=False)
215
-
216
- app.launch(show_api=False, max_threads=24)
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
4
+ import os
5
+ import zipfile
6
+ import shutil
7
+ import matplotlib.pyplot as plt
8
+ from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc
9
+ from tqdm import tqdm
10
  from PIL import Image
 
11
  import uuid
12
+ import tempfile
13
+ import pandas as pd
14
+ from numpy import exp
15
+ import numpy as np
16
+ from sklearn.metrics import ConfusionMatrixDisplay
17
+ import urllib.request
18
 
19
+ # Define models
20
  models = [
21
  "umm-maybe/AI-image-detector",
22
  "Organika/sdxl-detector",
 
28
  pipe2 = pipeline("image-classification", f"{models[2]}")
29
 
30
  fin_sum = []
31
+ uid = uuid.uuid4()
32
 
33
+ # Softmax function
34
  def softmax(vector):
35
  e = exp(vector - vector.max()) # for numerical stability
36
  return e / e.sum()
37
 
38
+ # Function to extract images from zip
39
+ def extract_zip(zip_file):
40
+ temp_dir = tempfile.mkdtemp() # Temporary directory
41
+ with zipfile.ZipFile(zip_file, 'r') as z:
42
+ z.extractall(temp_dir)
43
+ return temp_dir
44
+
45
+ # Function to classify images in a folder
46
+ def classify_images(image_dir, model_pipeline):
47
+ images = []
48
+ labels = []
49
+ preds = []
50
+ for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
51
+ folder_path = os.path.join(image_dir, folder_name)
52
+ if not os.path.exists(folder_path):
53
+ continue
54
+ for img_name in os.listdir(folder_path):
55
+ img_path = os.path.join(folder_path, img_name)
56
+ try:
57
+ img = Image.open(img_path).convert("RGB")
58
+ pred = model_pipeline(img)
59
+ pred_label = np.argmax([x['score'] for x in pred])
60
+ preds.append(pred_label)
61
+ labels.append(ground_truth_label)
62
+ images.append(img_name)
63
+ except Exception as e:
64
+ print(f"Error processing image {img_name}: {e}")
65
+ return labels, preds, images
66
+
67
+ # Function to generate evaluation metrics
68
+ def evaluate_model(labels, preds):
69
+ cm = confusion_matrix(labels, preds)
70
+ accuracy = accuracy_score(labels, preds)
71
+ roc_score = roc_auc_score(labels, preds)
72
+ report = classification_report(labels, preds)
73
+ fpr, tpr, _ = roc_curve(labels, preds)
74
+ roc_auc = auc(fpr, tpr)
75
+
76
+ fig, ax = plt.subplots()
77
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["AI", "Real"])
78
+ disp.plot(cmap=plt.cm.Blues, ax=ax)
79
+ plt.close(fig)
80
+
81
+ fig_roc, ax_roc = plt.subplots()
82
+ ax_roc.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
83
+ ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
84
+ ax_roc.set_xlim([0.0, 1.0])
85
+ ax_roc.set_ylim([0.0, 1.05])
86
+ ax_roc.set_xlabel('False Positive Rate')
87
+ ax_roc.set_ylabel('True Positive Rate')
88
+ ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
89
+ ax_roc.legend(loc="lower right")
90
+ plt.close(fig_roc)
91
+
92
+ return accuracy, roc_score, report, fig, fig_roc
93
+
94
+ # Gradio function for batch image processing
95
+ def process_zip(zip_file):
96
+ extracted_dir = extract_zip(zip_file.name)
97
+ labels, preds, images = classify_images(extracted_dir, pipe0) # You can switch to pipe1 or pipe2
98
+ accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
99
+ shutil.rmtree(extracted_dir) # Clean up extracted files
100
+ return accuracy, roc_score, report, cm_fig, roc_fig
101
+
102
+ # Single image classification functions
103
  def image_classifier0(image):
104
  labels = ["AI", "Real"]
105
  outputs = pipe0(image)
 
144
  html_out = f"""
145
  <h1>This image is likely: {label}</h1><br><h3>
146
  Probabilities:<br>
147
+ Real: {float(px[1][0]):.4f}<br>
148
+ AI: {float(px[0][0]):.4f}"""
149
 
150
  results = {
151
  "Real": float(px[1][0]),
 
171
  html_out = f"""
172
  <h1>This image is likely: {label}</h1><br><h3>
173
  Probabilities:<br>
174
+ Real: {float(px[1][0]):.4f}<br>
175
+ AI: {float(px[0][0]):.4f}"""
176
 
177
  results = {
178
  "Real": float(px[1][0]),
 
198
  html_out = f"""
199
  <h1>This image is likely: {label}</h1><br><h3>
200
  Probabilities:<br>
201
+ Real: {float(px[1][0]):.4f}<br>
202
+ AI: {float(px[0][0]):.4f}"""
203
 
204
  results = {
205
  "Real": float(px[1][0]),
 
223
  fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
224
  fin_sub = 1 - fin_out
225
  out = {
226
+ "Real": f"{fin_out:.4f}",
227
+ "AI": f"{fin_sub:.4f}"
228
  }
229
  return out
230
  except Exception as e:
 
241
  out = Image.open(f"{rand_im}-vid_tmp_proc.png")
242
  return out
243
 
244
+ # Set up Gradio app
245
  with gr.Blocks() as app:
246
+ gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
247
+
248
+ with gr.Tabs():
249
+ # Tab for single image detection
250
+ with gr.Tab("Single Image Detection"):
251
+ with gr.Column():
252
+ inp = gr.Image(type='pil')
253
+ in_url = gr.Textbox(label="Image URL")
254
+ with gr.Row():
255
+ load_btn = gr.Button("Load URL")
256
+ btn = gr.Button("Detect AI")
257
+ mes = gr.HTML("""""")
258
+
259
+ with gr.Group():
260
+ with gr.Row():
261
+ fin = gr.Label(label="Final Probability")
262
+ with gr.Row():
263
+ for i, model in enumerate(models):
264
+ with gr.Box():
265
+ gr.HTML(f"""<b>Testing on Model {i}: <a href='https://huggingface.co/{model}'>{model}</a></b>""")
266
+ globals()[f'outp{i}'] = gr.HTML("""""")
267
+ globals()[f'n_out{i}'] = gr.Label(label="Output")
268
+
269
+ btn.click(fin_clear, None, fin, show_progress=False)
270
+ load_btn.click(load_url, in_url, [inp, mes])
271
+
272
+ btn.click(aiornot0, [inp], [outp0, n_out0]).then(
273
+ aiornot1, [inp], [outp1, n_out1]).then(
274
+ aiornot2, [inp], [outp2, n_out2]).then(
275
+ tot_prob, None, fin, show_progress=False)
276
+
277
+ btn.click(image_classifier0, [inp], [n_out0]).then(
278
+ image_classifier1, [inp], [n_out1]).then(
279
+ image_classifier2, [inp], [n_out2]).then(
280
+ tot_prob, None, fin, show_progress=False)
281
+
282
+ # Tab for batch processing
283
+ with gr.Tab("Batch Image Processing"):
284
+ zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
285
+ output_acc = gr.Label(label="Accuracy")
286
+ output_roc = gr.Label(label="ROC Score")
287
+ output_report = gr.Textbox(label="Classification Report", lines=10)
288
+ output_cm = gr.Plot(label="Confusion Matrix")
289
+ output_roc_plot = gr.Plot(label="ROC Curve")
290
+
291
+ batch_btn = gr.Button("Process Batch")
292
+
293
+ # Connect batch processing
294
+ batch_btn.click(process_zip, zip_file, [output_acc, output_roc, output_report, output_cm, output_roc_plot])
295
+
296
+ app.launch(show_api=False, max_threads=24)