Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,25 +1,226 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
<div class="image-item">
|
6 |
<img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
|
7 |
<p>{img_name}</p>
|
8 |
</div>
|
9 |
-
|
10 |
-
html += '</div>'
|
11 |
-
|
12 |
return html
|
13 |
|
14 |
-
def
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
detector = AIDetector()
|
25 |
|
@@ -29,75 +230,47 @@ def create_gradio_interface():
|
|
29 |
|
30 |
with gr.Tabs():
|
31 |
with gr.Tab("Single Image Detection"):
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
with gr.Group():
|
41 |
-
with gr.Box():
|
42 |
-
gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{MODEL_NAME}'>{MODEL_NAME}</a></b>""")
|
43 |
-
output_html = gr.HTML()
|
44 |
-
output_label = gr.Label(label="Output")
|
45 |
|
46 |
with gr.Tab("Batch Image Processing"):
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
)
|
63 |
-
real_files = gr.File(
|
64 |
-
label="Upload Real Images",
|
65 |
-
file_types=["image"],
|
66 |
-
file_count="multiple"
|
67 |
-
)
|
68 |
-
individual_process_btn = gr.Button("Process Individual Files", interactive=False)
|
69 |
-
|
70 |
-
with gr.Group():
|
71 |
-
gr.Markdown(f"### Results for {MODEL_NAME}")
|
72 |
-
output_acc = gr.Label(label="Accuracy")
|
73 |
-
output_roc = gr.Label(label="ROC Score")
|
74 |
-
output_report = gr.HTML(label="Classification Report")
|
75 |
-
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
|
76 |
-
output_fp_fn = gr.HTML(label="False Positives and Negatives")
|
77 |
-
|
78 |
-
# Add export button and PDF output
|
79 |
-
export_btn = gr.Button("Export Results to PDF", variant="primary")
|
80 |
-
pdf_output = gr.File(label="Downloaded PDF")
|
81 |
|
82 |
reset_btn = gr.Button("Reset")
|
83 |
|
84 |
load_btn.click(load_url, in_url, [inp, message])
|
85 |
-
btn.click(
|
86 |
-
lambda img: detector.predict(img),
|
87 |
-
inp,
|
88 |
-
[output_html, output_label]
|
89 |
-
)
|
90 |
-
|
91 |
-
def enable_zip_btn(file):
|
92 |
-
return gr.Button.update(interactive=file is not None)
|
93 |
|
94 |
-
def
|
95 |
-
|
|
|
|
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
101 |
|
102 |
zip_process_btn.click(
|
103 |
process_zip,
|
@@ -111,37 +284,8 @@ def create_gradio_interface():
|
|
111 |
[output_acc, output_roc, output_report, output_plots, output_fp_fn]
|
112 |
)
|
113 |
|
114 |
-
# Add export button click handler
|
115 |
-
export_btn.click(
|
116 |
-
export_to_pdf,
|
117 |
-
inputs=[output_acc, output_roc, output_report, output_plots, output_fp_fn],
|
118 |
-
outputs=pdf_output
|
119 |
-
)
|
120 |
-
|
121 |
-
def reset_interface():
|
122 |
-
return [
|
123 |
-
None, None, None, None, None, # Reset inputs
|
124 |
-
gr.Button.update(interactive=False), # Reset zip process button
|
125 |
-
gr.Button.update(interactive=False), # Reset individual process button
|
126 |
-
None, None, None, None, None, None # Reset outputs (including PDF)
|
127 |
-
]
|
128 |
-
|
129 |
-
reset_btn.click(
|
130 |
-
reset_interface,
|
131 |
-
inputs=None,
|
132 |
-
outputs=[
|
133 |
-
zip_file, ai_files, real_files,
|
134 |
-
output_acc, output_roc, output_report, output_plots, output_fp_fn,
|
135 |
-
zip_process_btn, individual_process_btn, pdf_output
|
136 |
-
]
|
137 |
-
)
|
138 |
-
|
139 |
return app
|
140 |
|
141 |
if __name__ == "__main__":
|
142 |
app = create_gradio_interface()
|
143 |
-
app.launch(
|
144 |
-
show_api=False,
|
145 |
-
max_threads=24,
|
146 |
-
show_error=True
|
147 |
-
)
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
|
4 |
+
import os
|
5 |
+
import zipfile
|
6 |
+
import shutil
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay
|
9 |
+
from PIL import Image
|
10 |
+
import tempfile
|
11 |
+
import numpy as np
|
12 |
+
import urllib.request
|
13 |
+
import base64
|
14 |
+
from io import BytesIO
|
15 |
+
from reportlab.lib.pagesizes import letter
|
16 |
+
from reportlab.pdfgen import canvas
|
17 |
+
|
18 |
+
MODEL_NAME = "cmckinle/sdxl-flux-detector"
|
19 |
+
LABELS = ["AI", "Real"]
|
20 |
+
|
21 |
+
class AIDetector:
|
22 |
+
def __init__(self):
|
23 |
+
self.pipe = pipeline("image-classification", MODEL_NAME)
|
24 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
|
25 |
+
self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def softmax(vector):
|
29 |
+
e = np.exp(vector - np.max(vector))
|
30 |
+
return e / e.sum()
|
31 |
+
|
32 |
+
def predict(self, image):
|
33 |
+
inputs = self.feature_extractor(image, return_tensors="pt")
|
34 |
+
with torch.no_grad():
|
35 |
+
outputs = self.model(**inputs)
|
36 |
+
logits = outputs.logits
|
37 |
+
probabilities = self.softmax(logits.numpy())
|
38 |
+
|
39 |
+
prediction = logits.argmax(-1).item()
|
40 |
+
label = LABELS[prediction]
|
41 |
+
|
42 |
+
results = {label: float(prob) for label, prob in zip(LABELS, probabilities[0])}
|
43 |
+
|
44 |
+
return label, results
|
45 |
+
|
46 |
+
def process_zip(zip_file):
|
47 |
+
temp_dir = tempfile.mkdtemp()
|
48 |
+
|
49 |
+
try:
|
50 |
+
with zipfile.ZipFile(zip_file.name, 'r') as z:
|
51 |
+
file_list = z.namelist()
|
52 |
+
if not ('real/' in file_list and 'ai/' in file_list):
|
53 |
+
raise ValueError("Zip file must contain 'real' and 'ai' folders")
|
54 |
+
|
55 |
+
z.extractall(temp_dir)
|
56 |
+
|
57 |
+
return evaluate_model(temp_dir)
|
58 |
+
|
59 |
+
except Exception as e:
|
60 |
+
raise gr.Error(f"Error processing zip file: {str(e)}")
|
61 |
+
|
62 |
+
finally:
|
63 |
+
shutil.rmtree(temp_dir)
|
64 |
+
|
65 |
+
def process_files(ai_files, real_files):
|
66 |
+
temp_dir = tempfile.mkdtemp()
|
67 |
+
try:
|
68 |
+
ai_folder = os.path.join(temp_dir, 'ai')
|
69 |
+
os.makedirs(ai_folder)
|
70 |
+
for file in ai_files:
|
71 |
+
shutil.copy(file.name, os.path.join(ai_folder, os.path.basename(file.name)))
|
72 |
+
|
73 |
+
real_folder = os.path.join(temp_dir, 'real')
|
74 |
+
os.makedirs(real_folder)
|
75 |
+
for file in real_files:
|
76 |
+
shutil.copy(file.name, os.path.join(real_folder, os.path.basename(file.name)))
|
77 |
+
|
78 |
+
return evaluate_model(temp_dir)
|
79 |
+
except Exception as e:
|
80 |
+
raise gr.Error(f"Error processing individual files: {str(e)}")
|
81 |
+
finally:
|
82 |
+
shutil.rmtree(temp_dir)
|
83 |
+
|
84 |
+
def evaluate_model(temp_dir):
|
85 |
+
labels, preds, images = [], [], []
|
86 |
+
false_positives, false_negatives = [], []
|
87 |
+
detector = AIDetector()
|
88 |
+
|
89 |
+
total_images = sum(len(files) for _, _, files in os.walk(temp_dir))
|
90 |
+
processed_images = 0
|
91 |
+
|
92 |
+
for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
|
93 |
+
folder_path = os.path.join(temp_dir, folder_name)
|
94 |
+
if not os.path.exists(folder_path):
|
95 |
+
raise ValueError(f"Folder not found: {folder_path}")
|
96 |
+
|
97 |
+
for img_name in os.listdir(folder_path):
|
98 |
+
img_path = os.path.join(folder_path, img_name)
|
99 |
+
try:
|
100 |
+
with Image.open(img_path).convert("RGB") as img:
|
101 |
+
_, prediction = detector.predict(img)
|
102 |
+
|
103 |
+
pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
|
104 |
+
|
105 |
+
preds.append(pred_label)
|
106 |
+
labels.append(ground_truth_label)
|
107 |
+
images.append(img_name)
|
108 |
+
|
109 |
+
if pred_label != ground_truth_label:
|
110 |
+
with open(img_path, "rb") as img_file:
|
111 |
+
img_data = base64.b64encode(img_file.read()).decode()
|
112 |
+
if pred_label == 1 and ground_truth_label == 0:
|
113 |
+
false_positives.append((img_name, img_data))
|
114 |
+
elif pred_label == 0 and ground_truth_label == 1:
|
115 |
+
false_negatives.append((img_name, img_data))
|
116 |
+
|
117 |
+
except Exception as e:
|
118 |
+
print(f"Error processing image {img_name}: {e}")
|
119 |
+
|
120 |
+
processed_images += 1
|
121 |
+
gr.Progress(processed_images / total_images)
|
122 |
+
|
123 |
+
return calculate_metrics(labels, preds, false_positives, false_negatives)
|
124 |
+
|
125 |
+
def calculate_metrics(labels, preds, false_positives, false_negatives):
|
126 |
+
cm = confusion_matrix(labels, preds)
|
127 |
+
accuracy = accuracy_score(labels, preds)
|
128 |
+
roc_score = roc_auc_score(labels, preds)
|
129 |
+
report_html = format_classification_report(labels, preds)
|
130 |
+
fpr, tpr, _ = roc_curve(labels, preds)
|
131 |
+
roc_auc = auc(fpr, tpr)
|
132 |
+
|
133 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
134 |
+
|
135 |
+
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
|
136 |
+
ax1.set_title("Confusion Matrix")
|
137 |
+
|
138 |
+
ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
139 |
+
ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
|
140 |
+
ax2.set_xlim([0.0, 1.0])
|
141 |
+
ax2.set_ylim([0.0, 1.05])
|
142 |
+
ax2.set_xlabel('False Positive Rate')
|
143 |
+
ax2.set_ylabel('True Positive Rate')
|
144 |
+
ax2.set_title('ROC Curve')
|
145 |
+
ax2.legend(loc="lower right")
|
146 |
+
|
147 |
+
plt.tight_layout()
|
148 |
+
|
149 |
+
fp_fn_html = create_fp_fn_html(false_positives, false_negatives)
|
150 |
+
|
151 |
+
return accuracy, roc_score, report_html, fig, fp_fn_html
|
152 |
+
|
153 |
+
def format_classification_report(labels, preds):
|
154 |
+
report_dict = classification_report(labels, preds, output_dict=True)
|
155 |
+
|
156 |
+
html = """
|
157 |
+
<table class="report-table">
|
158 |
+
<tr>
|
159 |
+
<th>Class</th>
|
160 |
+
<th>Precision</th>
|
161 |
+
<th>Recall</th>
|
162 |
+
<th>F1-Score</th>
|
163 |
+
<th>Support</th>
|
164 |
+
</tr>
|
165 |
+
"""
|
166 |
+
|
167 |
+
for class_name in ['0', '1']:
|
168 |
+
html += f"""
|
169 |
+
<tr>
|
170 |
+
<td>{class_name}</td>
|
171 |
+
<td>{report_dict[class_name]['precision']:.2f}</td>
|
172 |
+
<td>{report_dict[class_name]['recall']:.2f}</td>
|
173 |
+
<td>{report_dict[class_name]['f1-score']:.2f}</td>
|
174 |
+
<td>{report_dict[class_name]['support']}</td>
|
175 |
+
</tr>
|
176 |
+
"""
|
177 |
+
|
178 |
+
html += f"""
|
179 |
+
<tr>
|
180 |
+
<td>Accuracy</td>
|
181 |
+
<td colspan="3">{report_dict['accuracy']:.2f}</td>
|
182 |
+
<td>{report_dict['macro avg']['support']}</td>
|
183 |
+
</tr>
|
184 |
+
</table>
|
185 |
+
"""
|
186 |
+
|
187 |
+
return html
|
188 |
+
|
189 |
+
def create_fp_fn_html(false_positives, false_negatives):
|
190 |
+
html = "<div class='image-grid'>"
|
191 |
+
for img_name, img_data in false_positives + false_negatives:
|
192 |
+
html += f"""
|
193 |
<div class="image-item">
|
194 |
<img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
|
195 |
<p>{img_name}</p>
|
196 |
</div>
|
197 |
+
"""
|
|
|
|
|
198 |
return html
|
199 |
|
200 |
+
def generate_pdf(accuracy, roc_score, report_html, confusion_matrix_plot):
|
201 |
+
buffer = BytesIO()
|
202 |
+
c = canvas.Canvas(buffer, pagesize=letter)
|
203 |
+
|
204 |
+
c.drawString(100, 750, f"Model Results")
|
205 |
+
c.drawString(100, 730, f"Accuracy: {accuracy:.2f}")
|
206 |
+
c.drawString(100, 710, f"ROC Score: {roc_score:.2f}")
|
207 |
+
|
208 |
+
y_position = 690
|
209 |
+
for line in report_html.split('<tr>')[2:]:
|
210 |
+
if y_position < 50:
|
211 |
+
c.showPage()
|
212 |
+
y_position = 750
|
213 |
+
c.drawString(100, y_position, line.strip())
|
214 |
+
y_position -= 20
|
215 |
+
|
216 |
+
img_buffer = BytesIO()
|
217 |
+
confusion_matrix_plot.savefig(img_buffer, format="png")
|
218 |
+
img_buffer.seek(0)
|
219 |
+
c.drawImage(img_buffer, 100, y_position - 250, width=400, height=300)
|
220 |
+
|
221 |
+
c.save()
|
222 |
+
buffer.seek(0)
|
223 |
+
return buffer
|
224 |
|
225 |
detector = AIDetector()
|
226 |
|
|
|
230 |
|
231 |
with gr.Tabs():
|
232 |
with gr.Tab("Single Image Detection"):
|
233 |
+
inp = gr.Image(type='pil')
|
234 |
+
in_url = gr.Textbox(label="Image URL")
|
235 |
+
load_btn = gr.Button("Load URL")
|
236 |
+
btn = gr.Button("Detect AI")
|
237 |
+
message = gr.HTML()
|
238 |
+
|
239 |
+
output_html = gr.HTML()
|
240 |
+
output_label = gr.Label(label="Output")
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
with gr.Tab("Batch Image Processing"):
|
243 |
+
zip_file = gr.File(label="Upload Zip", file_types=[".zip"], file_count="single")
|
244 |
+
zip_process_btn = gr.Button("Process Zip")
|
245 |
+
|
246 |
+
ai_files = gr.File(label="Upload AI Images", file_types=["image"], file_count="multiple")
|
247 |
+
real_files = gr.File(label="Upload Real Images", file_types=["image"], file_count="multiple")
|
248 |
+
individual_process_btn = gr.Button("Process Individual Files")
|
249 |
+
|
250 |
+
output_acc = gr.Label(label="Accuracy")
|
251 |
+
output_roc = gr.Label(label="ROC Score")
|
252 |
+
output_report = gr.HTML(label="Classification Report")
|
253 |
+
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
|
254 |
+
output_fp_fn = gr.HTML(label="False Positives and Negatives")
|
255 |
+
|
256 |
+
download_pdf_btn = gr.Button("Download Results as PDF")
|
257 |
+
pdf_output = gr.File(label="Download PDF", visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
reset_btn = gr.Button("Reset")
|
260 |
|
261 |
load_btn.click(load_url, in_url, [inp, message])
|
262 |
+
btn.click(lambda img: detector.predict(img), inp, [output_html, output_label])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
+
def on_download_pdf(accuracy, roc_score, report_html, confusion_matrix_plot):
|
265 |
+
pdf_buffer = generate_pdf(accuracy, roc_score, report_html, confusion_matrix_plot)
|
266 |
+
pdf_buffer.seek(0)
|
267 |
+
return pdf_buffer
|
268 |
|
269 |
+
download_pdf_btn.click(
|
270 |
+
on_download_pdf,
|
271 |
+
inputs=[output_acc, output_roc, output_report, output_plots],
|
272 |
+
outputs=pdf_output
|
273 |
+
)
|
274 |
|
275 |
zip_process_btn.click(
|
276 |
process_zip,
|
|
|
284 |
[output_acc, output_roc, output_report, output_plots, output_fp_fn]
|
285 |
)
|
286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
return app
|
288 |
|
289 |
if __name__ == "__main__":
|
290 |
app = create_gradio_interface()
|
291 |
+
app.launch(show_api=False, max_threads=24, show_error=True)
|
|
|
|
|
|
|
|