Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,22 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
|
4 |
-
import os
|
5 |
-
|
6 |
-
import
|
|
|
|
|
|
|
7 |
from PIL import Image
|
8 |
-
import urllib.request
|
9 |
import uuid
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
#
|
13 |
models = [
|
14 |
"umm-maybe/AI-image-detector",
|
15 |
"Organika/sdxl-detector",
|
@@ -21,11 +28,78 @@ pipe1 = pipeline("image-classification", f"{models[1]}")
|
|
21 |
pipe2 = pipeline("image-classification", f"{models[2]}")
|
22 |
|
23 |
fin_sum = []
|
|
|
24 |
|
|
|
25 |
def softmax(vector):
|
26 |
e = exp(vector - vector.max()) # for numerical stability
|
27 |
return e / e.sum()
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def image_classifier0(image):
|
30 |
labels = ["AI", "Real"]
|
31 |
outputs = pipe0(image)
|
@@ -70,8 +144,8 @@ def aiornot0(image):
|
|
70 |
html_out = f"""
|
71 |
<h1>This image is likely: {label}</h1><br><h3>
|
72 |
Probabilities:<br>
|
73 |
-
Real: {float(px[1][0])}<br>
|
74 |
-
AI: {float(px[0][0])}"""
|
75 |
|
76 |
results = {
|
77 |
"Real": float(px[1][0]),
|
@@ -97,8 +171,8 @@ def aiornot1(image):
|
|
97 |
html_out = f"""
|
98 |
<h1>This image is likely: {label}</h1><br><h3>
|
99 |
Probabilities:<br>
|
100 |
-
Real: {float(px[1][0])}<br>
|
101 |
-
AI: {float(px[0][0])}"""
|
102 |
|
103 |
results = {
|
104 |
"Real": float(px[1][0]),
|
@@ -124,8 +198,8 @@ def aiornot2(image):
|
|
124 |
html_out = f"""
|
125 |
<h1>This image is likely: {label}</h1><br><h3>
|
126 |
Probabilities:<br>
|
127 |
-
Real: {float(px[1][0])}<br>
|
128 |
-
AI: {float(px[0][0])}"""
|
129 |
|
130 |
results = {
|
131 |
"Real": float(px[1][0]),
|
@@ -149,8 +223,8 @@ def tot_prob():
|
|
149 |
fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
|
150 |
fin_sub = 1 - fin_out
|
151 |
out = {
|
152 |
-
"Real": f"{fin_out}",
|
153 |
-
"AI": f"{fin_sub}"
|
154 |
}
|
155 |
return out
|
156 |
except Exception as e:
|
@@ -167,50 +241,56 @@ def upd(image):
|
|
167 |
out = Image.open(f"{rand_im}-vid_tmp_proc.png")
|
168 |
return out
|
169 |
|
|
|
170 |
with gr.Blocks() as app:
|
171 |
-
gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)""")
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
with gr.
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
|
4 |
+
import os
|
5 |
+
import zipfile
|
6 |
+
import shutil
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc
|
9 |
+
from tqdm import tqdm
|
10 |
from PIL import Image
|
|
|
11 |
import uuid
|
12 |
+
import tempfile
|
13 |
+
import pandas as pd
|
14 |
+
from numpy import exp
|
15 |
+
import numpy as np
|
16 |
+
from sklearn.metrics import ConfusionMatrixDisplay
|
17 |
+
import urllib.request
|
18 |
|
19 |
+
# Define models
|
20 |
models = [
|
21 |
"umm-maybe/AI-image-detector",
|
22 |
"Organika/sdxl-detector",
|
|
|
28 |
pipe2 = pipeline("image-classification", f"{models[2]}")
|
29 |
|
30 |
fin_sum = []
|
31 |
+
uid = uuid.uuid4()
|
32 |
|
33 |
+
# Softmax function
|
34 |
def softmax(vector):
|
35 |
e = exp(vector - vector.max()) # for numerical stability
|
36 |
return e / e.sum()
|
37 |
|
38 |
+
# Function to extract images from zip
|
39 |
+
def extract_zip(zip_file):
|
40 |
+
temp_dir = tempfile.mkdtemp() # Temporary directory
|
41 |
+
with zipfile.ZipFile(zip_file, 'r') as z:
|
42 |
+
z.extractall(temp_dir)
|
43 |
+
return temp_dir
|
44 |
+
|
45 |
+
# Function to classify images in a folder
|
46 |
+
def classify_images(image_dir, model_pipeline):
|
47 |
+
images = []
|
48 |
+
labels = []
|
49 |
+
preds = []
|
50 |
+
for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
|
51 |
+
folder_path = os.path.join(image_dir, folder_name)
|
52 |
+
if not os.path.exists(folder_path):
|
53 |
+
continue
|
54 |
+
for img_name in os.listdir(folder_path):
|
55 |
+
img_path = os.path.join(folder_path, img_name)
|
56 |
+
try:
|
57 |
+
img = Image.open(img_path).convert("RGB")
|
58 |
+
pred = model_pipeline(img)
|
59 |
+
pred_label = np.argmax([x['score'] for x in pred])
|
60 |
+
preds.append(pred_label)
|
61 |
+
labels.append(ground_truth_label)
|
62 |
+
images.append(img_name)
|
63 |
+
except Exception as e:
|
64 |
+
print(f"Error processing image {img_name}: {e}")
|
65 |
+
return labels, preds, images
|
66 |
+
|
67 |
+
# Function to generate evaluation metrics
|
68 |
+
def evaluate_model(labels, preds):
|
69 |
+
cm = confusion_matrix(labels, preds)
|
70 |
+
accuracy = accuracy_score(labels, preds)
|
71 |
+
roc_score = roc_auc_score(labels, preds)
|
72 |
+
report = classification_report(labels, preds)
|
73 |
+
fpr, tpr, _ = roc_curve(labels, preds)
|
74 |
+
roc_auc = auc(fpr, tpr)
|
75 |
+
|
76 |
+
fig, ax = plt.subplots()
|
77 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["AI", "Real"])
|
78 |
+
disp.plot(cmap=plt.cm.Blues, ax=ax)
|
79 |
+
plt.close(fig)
|
80 |
+
|
81 |
+
fig_roc, ax_roc = plt.subplots()
|
82 |
+
ax_roc.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
83 |
+
ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
|
84 |
+
ax_roc.set_xlim([0.0, 1.0])
|
85 |
+
ax_roc.set_ylim([0.0, 1.05])
|
86 |
+
ax_roc.set_xlabel('False Positive Rate')
|
87 |
+
ax_roc.set_ylabel('True Positive Rate')
|
88 |
+
ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
|
89 |
+
ax_roc.legend(loc="lower right")
|
90 |
+
plt.close(fig_roc)
|
91 |
+
|
92 |
+
return accuracy, roc_score, report, fig, fig_roc
|
93 |
+
|
94 |
+
# Gradio function for batch image processing
|
95 |
+
def process_zip(zip_file):
|
96 |
+
extracted_dir = extract_zip(zip_file.name)
|
97 |
+
labels, preds, images = classify_images(extracted_dir, pipe0) # You can switch to pipe1 or pipe2
|
98 |
+
accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
|
99 |
+
shutil.rmtree(extracted_dir) # Clean up extracted files
|
100 |
+
return accuracy, roc_score, report, cm_fig, roc_fig
|
101 |
+
|
102 |
+
# Single image classification functions
|
103 |
def image_classifier0(image):
|
104 |
labels = ["AI", "Real"]
|
105 |
outputs = pipe0(image)
|
|
|
144 |
html_out = f"""
|
145 |
<h1>This image is likely: {label}</h1><br><h3>
|
146 |
Probabilities:<br>
|
147 |
+
Real: {float(px[1][0]):.4f}<br>
|
148 |
+
AI: {float(px[0][0]):.4f}"""
|
149 |
|
150 |
results = {
|
151 |
"Real": float(px[1][0]),
|
|
|
171 |
html_out = f"""
|
172 |
<h1>This image is likely: {label}</h1><br><h3>
|
173 |
Probabilities:<br>
|
174 |
+
Real: {float(px[1][0]):.4f}<br>
|
175 |
+
AI: {float(px[0][0]):.4f}"""
|
176 |
|
177 |
results = {
|
178 |
"Real": float(px[1][0]),
|
|
|
198 |
html_out = f"""
|
199 |
<h1>This image is likely: {label}</h1><br><h3>
|
200 |
Probabilities:<br>
|
201 |
+
Real: {float(px[1][0]):.4f}<br>
|
202 |
+
AI: {float(px[0][0]):.4f}"""
|
203 |
|
204 |
results = {
|
205 |
"Real": float(px[1][0]),
|
|
|
223 |
fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
|
224 |
fin_sub = 1 - fin_out
|
225 |
out = {
|
226 |
+
"Real": f"{fin_out:.4f}",
|
227 |
+
"AI": f"{fin_sub:.4f}"
|
228 |
}
|
229 |
return out
|
230 |
except Exception as e:
|
|
|
241 |
out = Image.open(f"{rand_im}-vid_tmp_proc.png")
|
242 |
return out
|
243 |
|
244 |
+
# Set up Gradio app
|
245 |
with gr.Blocks() as app:
|
246 |
+
gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
|
247 |
+
|
248 |
+
with gr.Tabs():
|
249 |
+
# Tab for single image detection
|
250 |
+
with gr.Tab("Single Image Detection"):
|
251 |
+
with gr.Column():
|
252 |
+
inp = gr.Image(type='pil')
|
253 |
+
in_url = gr.Textbox(label="Image URL")
|
254 |
+
with gr.Row():
|
255 |
+
load_btn = gr.Button("Load URL")
|
256 |
+
btn = gr.Button("Detect AI")
|
257 |
+
mes = gr.HTML("""""")
|
258 |
+
|
259 |
+
with gr.Group():
|
260 |
+
with gr.Row():
|
261 |
+
fin = gr.Label(label="Final Probability")
|
262 |
+
with gr.Row():
|
263 |
+
for i, model in enumerate(models):
|
264 |
+
with gr.Box():
|
265 |
+
gr.HTML(f"""<b>Testing on Model {i}: <a href='https://huggingface.co/{model}'>{model}</a></b>""")
|
266 |
+
globals()[f'outp{i}'] = gr.HTML("""""")
|
267 |
+
globals()[f'n_out{i}'] = gr.Label(label="Output")
|
268 |
+
|
269 |
+
btn.click(fin_clear, None, fin, show_progress=False)
|
270 |
+
load_btn.click(load_url, in_url, [inp, mes])
|
271 |
+
|
272 |
+
btn.click(aiornot0, [inp], [outp0, n_out0]).then(
|
273 |
+
aiornot1, [inp], [outp1, n_out1]).then(
|
274 |
+
aiornot2, [inp], [outp2, n_out2]).then(
|
275 |
+
tot_prob, None, fin, show_progress=False)
|
276 |
+
|
277 |
+
btn.click(image_classifier0, [inp], [n_out0]).then(
|
278 |
+
image_classifier1, [inp], [n_out1]).then(
|
279 |
+
image_classifier2, [inp], [n_out2]).then(
|
280 |
+
tot_prob, None, fin, show_progress=False)
|
281 |
+
|
282 |
+
# Tab for batch processing
|
283 |
+
with gr.Tab("Batch Image Processing"):
|
284 |
+
zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
|
285 |
+
output_acc = gr.Label(label="Accuracy")
|
286 |
+
output_roc = gr.Label(label="ROC Score")
|
287 |
+
output_report = gr.Textbox(label="Classification Report", lines=10)
|
288 |
+
output_cm = gr.Plot(label="Confusion Matrix")
|
289 |
+
output_roc_plot = gr.Plot(label="ROC Curve")
|
290 |
+
|
291 |
+
batch_btn = gr.Button("Process Batch")
|
292 |
+
|
293 |
+
# Connect batch processing
|
294 |
+
batch_btn.click(process_zip, zip_file, [output_acc, output_roc, output_report, output_cm, output_roc_plot])
|
295 |
+
|
296 |
+
app.launch(show_api=False, max_threads=24)
|