Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -53,46 +53,7 @@ def process_zip(zip_file):
|
|
53 |
|
54 |
z.extractall(temp_dir)
|
55 |
|
56 |
-
|
57 |
-
false_positives, false_negatives = [], []
|
58 |
-
detector = AIDetector()
|
59 |
-
|
60 |
-
total_images = sum(len(files) for _, _, files in os.walk(temp_dir))
|
61 |
-
processed_images = 0
|
62 |
-
|
63 |
-
for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
|
64 |
-
folder_path = os.path.join(temp_dir, folder_name)
|
65 |
-
if not os.path.exists(folder_path):
|
66 |
-
raise ValueError(f"Folder not found: {folder_path}")
|
67 |
-
|
68 |
-
for img_name in os.listdir(folder_path):
|
69 |
-
img_path = os.path.join(folder_path, img_name)
|
70 |
-
try:
|
71 |
-
with Image.open(img_path).convert("RGB") as img:
|
72 |
-
_, prediction = detector.predict(img)
|
73 |
-
|
74 |
-
pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
|
75 |
-
|
76 |
-
preds.append(pred_label)
|
77 |
-
labels.append(ground_truth_label)
|
78 |
-
images.append(img_name)
|
79 |
-
|
80 |
-
# Collect false positives and false negatives with image data
|
81 |
-
if pred_label != ground_truth_label:
|
82 |
-
with open(img_path, "rb") as img_file:
|
83 |
-
img_data = base64.b64encode(img_file.read()).decode()
|
84 |
-
if pred_label == 1 and ground_truth_label == 0:
|
85 |
-
false_positives.append((img_name, img_data))
|
86 |
-
elif pred_label == 0 and ground_truth_label == 1:
|
87 |
-
false_negatives.append((img_name, img_data))
|
88 |
-
|
89 |
-
except Exception as e:
|
90 |
-
print(f"Error processing image {img_name}: {e}")
|
91 |
-
|
92 |
-
processed_images += 1
|
93 |
-
gr.Progress(processed_images / total_images)
|
94 |
-
|
95 |
-
return evaluate_model(labels, preds, false_positives, false_negatives)
|
96 |
|
97 |
except Exception as e:
|
98 |
raise gr.Error(f"Error processing zip file: {str(e)}")
|
@@ -100,11 +61,100 @@ def process_zip(zip_file):
|
|
100 |
finally:
|
101 |
shutil.rmtree(temp_dir)
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
def format_classification_report(labels, preds):
|
104 |
-
# Convert the report string to a dictionary
|
105 |
report_dict = classification_report(labels, preds, output_dict=True)
|
106 |
|
107 |
-
# Create an HTML table with updated CSS
|
108 |
html = """
|
109 |
<style>
|
110 |
.report-table {
|
@@ -170,7 +220,6 @@ def format_classification_report(labels, preds):
|
|
170 |
</tr>
|
171 |
"""
|
172 |
|
173 |
-
# Add rows for each class
|
174 |
for class_name in ['0', '1']:
|
175 |
html += f"""
|
176 |
<tr>
|
@@ -182,7 +231,6 @@ def format_classification_report(labels, preds):
|
|
182 |
</tr>
|
183 |
"""
|
184 |
|
185 |
-
# Add summary rows
|
186 |
html += f"""
|
187 |
<tr>
|
188 |
<td>Accuracy</td>
|
@@ -207,33 +255,9 @@ def format_classification_report(labels, preds):
|
|
207 |
"""
|
208 |
|
209 |
return html
|
210 |
-
|
211 |
-
def evaluate_model(labels, preds, false_positives, false_negatives):
|
212 |
-
cm = confusion_matrix(labels, preds)
|
213 |
-
accuracy = accuracy_score(labels, preds)
|
214 |
-
roc_score = roc_auc_score(labels, preds)
|
215 |
-
report_html = format_classification_report(labels, preds)
|
216 |
-
fpr, tpr, _ = roc_curve(labels, preds)
|
217 |
-
roc_auc = auc(fpr, tpr)
|
218 |
-
|
219 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
220 |
-
|
221 |
-
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
|
222 |
-
ax1.set_title("Confusion Matrix")
|
223 |
-
|
224 |
-
ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
225 |
-
ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
|
226 |
-
ax2.set_xlim([0.0, 1.0])
|
227 |
-
ax2.set_ylim([0.0, 1.05])
|
228 |
-
ax2.set_xlabel('False Positive Rate')
|
229 |
-
ax2.set_ylabel('True Positive Rate')
|
230 |
-
ax2.set_title('ROC Curve')
|
231 |
-
ax2.legend(loc="lower right")
|
232 |
-
|
233 |
-
plt.tight_layout()
|
234 |
|
235 |
-
|
236 |
-
|
237 |
<style>
|
238 |
.image-grid {
|
239 |
display: flex;
|
@@ -252,29 +276,29 @@ def evaluate_model(labels, preds, false_positives, false_negatives):
|
|
252 |
</style>
|
253 |
"""
|
254 |
|
255 |
-
|
256 |
-
|
257 |
for img_name, img_data in false_positives:
|
258 |
-
|
259 |
<div class="image-item">
|
260 |
<img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
|
261 |
<p>{img_name}</p>
|
262 |
</div>
|
263 |
'''
|
264 |
-
|
265 |
|
266 |
-
|
267 |
-
|
268 |
for img_name, img_data in false_negatives:
|
269 |
-
|
270 |
<div class="image-item">
|
271 |
<img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
|
272 |
<p>{img_name}</p>
|
273 |
</div>
|
274 |
'''
|
275 |
-
|
276 |
|
277 |
-
return
|
278 |
|
279 |
def load_url(url):
|
280 |
try:
|
@@ -309,13 +333,28 @@ def create_gradio_interface():
|
|
309 |
output_label = gr.Label(label="Output")
|
310 |
|
311 |
with gr.Tab("Batch Image Processing"):
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
with gr.Group():
|
321 |
gr.Markdown(f"### Results for {MODEL_NAME}")
|
@@ -332,28 +371,7 @@ def create_gradio_interface():
|
|
332 |
[output_html, output_label]
|
333 |
)
|
334 |
|
335 |
-
def
|
336 |
return gr.Button.update(interactive=file is not None)
|
337 |
|
338 |
-
|
339 |
-
enable_batch_btn,
|
340 |
-
zip_file,
|
341 |
-
batch_btn
|
342 |
-
)
|
343 |
-
|
344 |
-
batch_btn.click(
|
345 |
-
process_zip,
|
346 |
-
zip_file,
|
347 |
-
[output_acc, output_roc, output_report, output_plots, output_fp_fn],
|
348 |
-
api_name="batch_process"
|
349 |
-
)
|
350 |
-
|
351 |
-
return app
|
352 |
-
|
353 |
-
if __name__ == "__main__":
|
354 |
-
app = create_gradio_interface()
|
355 |
-
app.launch(
|
356 |
-
show_api=False,
|
357 |
-
max_threads=24,
|
358 |
-
show_error=True
|
359 |
-
)
|
|
|
53 |
|
54 |
z.extractall(temp_dir)
|
55 |
|
56 |
+
return evaluate_model(temp_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
except Exception as e:
|
59 |
raise gr.Error(f"Error processing zip file: {str(e)}")
|
|
|
61 |
finally:
|
62 |
shutil.rmtree(temp_dir)
|
63 |
|
64 |
+
def process_files(ai_files, real_files):
    """Evaluate the detector on individually uploaded AI and real images.

    The uploads are copied into a fresh temporary directory laid out as
    ``<temp_dir>/ai`` and ``<temp_dir>/real``, and that directory is passed
    to ``evaluate_model``. The temporary directory is always removed before
    this function returns or raises.

    Args:
        ai_files: Iterable of uploaded AI images. Each item may be a path
            string or an object exposing its path as ``.name``. ``None`` is
            treated as "no files".
        real_files: Iterable of uploaded real images, same conventions as
            ``ai_files``.

    Returns:
        Whatever ``evaluate_model`` returns for the staged directory.

    Raises:
        gr.Error: If copying the uploads or evaluating them fails.
    """

    def _stage_uploads(uploads, target_dir):
        """Copy ``uploads`` into a newly created ``target_dir``."""
        os.makedirs(target_dir)
        for upload in uploads or []:
            # Accept both plain path strings and file wrappers with ``.name``.
            source_path = getattr(upload, "name", upload)
            stem, ext = os.path.splitext(os.path.basename(source_path))
            destination = os.path.join(target_dir, stem + ext)
            # Two uploads may share a basename; add a numeric suffix rather
            # than silently overwriting the earlier copy.
            suffix = 1
            while os.path.exists(destination):
                destination = os.path.join(target_dir, f"{stem}_{suffix}{ext}")
                suffix += 1
            shutil.copy(source_path, destination)

    temp_dir = tempfile.mkdtemp()
    try:
        _stage_uploads(ai_files, os.path.join(temp_dir, 'ai'))
        _stage_uploads(real_files, os.path.join(temp_dir, 'real'))
        return evaluate_model(temp_dir)
    except Exception as e:
        # Surface every failure to the UI, keeping the original cause chained.
        raise gr.Error(f"Error processing individual files: {str(e)}") from e
    finally:
        shutil.rmtree(temp_dir)
|
84 |
+
|
85 |
+
def evaluate_model(temp_dir, progress=None):
    """Run the detector over ``temp_dir/real`` and ``temp_dir/ai`` and score it.

    Every entry in the two folders is opened as an image, converted to RGB
    and passed to ``AIDetector.predict``. The predicted label is 0 (AI) when
    the "AI" score is strictly greater than the "Real" score, otherwise
    1 (Real). Misclassified images are kept as base64 strings so they can be
    displayed afterwards.

    Args:
        temp_dir: Directory that must contain a ``real`` and an ``ai`` folder.
        progress: Optional callable taking a fraction in ``[0, 1]``, e.g. a
            ``gr.Progress`` instance injected into the Gradio event handler.
            When ``None`` no progress is reported.

    Returns:
        The result of ``calculate_metrics`` for the collected labels,
        predictions, false positives and false negatives.

    Raises:
        ValueError: If either the ``real`` or the ``ai`` folder is missing.
    """
    labels, preds = [], []
    false_positives, false_negatives = [], []
    detector = AIDetector()

    total_images = sum(len(files) for _, _, files in os.walk(temp_dir))
    processed_images = 0

    for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
        folder_path = os.path.join(temp_dir, folder_name)
        if not os.path.exists(folder_path):
            raise ValueError(f"Folder not found: {folder_path}")

        # Sort so the evaluation order (and the displayed image order) is
        # deterministic; os.listdir returns entries in arbitrary order.
        for img_name in sorted(os.listdir(folder_path)):
            img_path = os.path.join(folder_path, img_name)
            try:
                # Close the source image explicitly; convert() returns an
                # independent, fully loaded copy.
                with Image.open(img_path) as source_img:
                    img = source_img.convert("RGB")
                _, prediction = detector.predict(img)

                pred_label = 0 if prediction["AI"] > prediction["Real"] else 1

                preds.append(pred_label)
                labels.append(ground_truth_label)

                # Collect false positives and false negatives with image data
                if pred_label != ground_truth_label:
                    with open(img_path, "rb") as img_file:
                        img_data = base64.b64encode(img_file.read()).decode()
                    if ground_truth_label == 0:
                        # AI image predicted as Real.
                        false_positives.append((img_name, img_data))
                    else:
                        # Real image predicted as AI.
                        false_negatives.append((img_name, img_data))

            except Exception as e:
                # Best effort: one unreadable file must not abort the batch.
                print(f"Error processing image {img_name}: {e}")

            processed_images += 1
            # Report only when a callback was supplied. Guard against a zero
            # total, and clamp because listdir may count sub-directories that
            # the recursive file count does not.
            if progress is not None and total_images:
                progress(min(1.0, processed_images / total_images))

    return calculate_metrics(labels, preds, false_positives, false_negatives)
|
126 |
+
|
127 |
+
def calculate_metrics(labels, preds, false_positives, false_negatives):
    """Compute evaluation metrics and build the plots and HTML for the UI.

    Args:
        labels: Ground-truth labels, one per evaluated image.
        preds: Predicted labels, aligned with ``labels``.
        false_positives: ``(image_name, base64_data)`` pairs passed through
            to ``create_fp_fn_html``.
        false_negatives: ``(image_name, base64_data)`` pairs passed through
            to ``create_fp_fn_html``.

    Returns:
        Tuple ``(accuracy, roc_score, report_html, fig, fp_fn_html)`` where
        ``fig`` is a matplotlib figure with the confusion matrix on the left
        and the ROC curve on the right.

    Raises:
        ValueError: If ``labels`` is empty or contains only one class; ROC
            AUC is undefined in that case.
    """
    # Fail early with an actionable message instead of an opaque error raised
    # from inside scikit-learn.
    if not labels:
        raise ValueError("No images could be processed; cannot compute metrics.")
    if len(set(labels)) < 2:
        raise ValueError(
            "Images from both the 'real' and 'ai' classes are required to compute metrics."
        )

    cm = confusion_matrix(labels, preds)
    accuracy = accuracy_score(labels, preds)
    roc_score = roc_auc_score(labels, preds)
    report_html = format_classification_report(labels, preds)
    fpr, tpr, _ = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
    ax1.set_title("Confusion Matrix")

    ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    ax2.legend(loc="lower right")

    fig.tight_layout()
    # Drop the figure from pyplot's global registry so repeated calls do not
    # accumulate open figures; the Figure object itself can still be rendered.
    plt.close(fig)

    fp_fn_html = create_fp_fn_html(false_positives, false_negatives)

    return accuracy, roc_score, report_html, fig, fp_fn_html
|
154 |
+
|
155 |
def format_classification_report(labels, preds):
|
|
|
156 |
report_dict = classification_report(labels, preds, output_dict=True)
|
157 |
|
|
|
158 |
html = """
|
159 |
<style>
|
160 |
.report-table {
|
|
|
220 |
</tr>
|
221 |
"""
|
222 |
|
|
|
223 |
for class_name in ['0', '1']:
|
224 |
html += f"""
|
225 |
<tr>
|
|
|
231 |
</tr>
|
232 |
"""
|
233 |
|
|
|
234 |
html += f"""
|
235 |
<tr>
|
236 |
<td>Accuracy</td>
|
|
|
255 |
"""
|
256 |
|
257 |
return html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
+
def create_fp_fn_html(false_positives, false_negatives):
|
260 |
+
html = """
|
261 |
<style>
|
262 |
.image-grid {
|
263 |
display: flex;
|
|
|
276 |
</style>
|
277 |
"""
|
278 |
|
279 |
+
html += "<h3>False Positives (AI images classified as Real):</h3>"
|
280 |
+
html += '<div class="image-grid">'
|
281 |
for img_name, img_data in false_positives:
|
282 |
+
html += f'''
|
283 |
<div class="image-item">
|
284 |
<img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
|
285 |
<p>{img_name}</p>
|
286 |
</div>
|
287 |
'''
|
288 |
+
html += '</div>'
|
289 |
|
290 |
+
html += "<h3>False Negatives (Real images classified as AI):</h3>"
|
291 |
+
html += '<div class="image-grid">'
|
292 |
for img_name, img_data in false_negatives:
|
293 |
+
html += f'''
|
294 |
<div class="image-item">
|
295 |
<img src="data:image/jpeg;base64,{img_data}" alt="{img_name}">
|
296 |
<p>{img_name}</p>
|
297 |
</div>
|
298 |
'''
|
299 |
+
html += '</div>'
|
300 |
|
301 |
+
return html
|
302 |
|
303 |
def load_url(url):
|
304 |
try:
|
|
|
333 |
output_label = gr.Label(label="Output")
|
334 |
|
335 |
with gr.Tab("Batch Image Processing"):
|
336 |
+
with gr.Accordion("Upload Zip File (max 100MB)", open=False):
|
337 |
+
zip_file = gr.File(
|
338 |
+
label="Upload Zip (must contain 'real' and 'ai' folders)",
|
339 |
+
file_types=[".zip"],
|
340 |
+
file_count="single",
|
341 |
+
max_file_size=100 # 100 MB limit
|
342 |
+
)
|
343 |
+
zip_process_btn = gr.Button("Process Zip", interactive=False)
|
344 |
+
|
345 |
+
with gr.Accordion("Upload Individual Files (for datasets over 100MB)", open=False):
|
346 |
+
with gr.Row():
|
347 |
+
ai_files = gr.File(
|
348 |
+
label="Upload AI Images",
|
349 |
+
file_types=["image"],
|
350 |
+
file_count="multiple"
|
351 |
+
)
|
352 |
+
real_files = gr.File(
|
353 |
+
label="Upload Real Images",
|
354 |
+
file_types=["image"],
|
355 |
+
file_count="multiple"
|
356 |
+
)
|
357 |
+
individual_process_btn = gr.Button("Process Individual Files", interactive=False)
|
358 |
|
359 |
with gr.Group():
|
360 |
gr.Markdown(f"### Results for {MODEL_NAME}")
|
|
|
371 |
[output_html, output_label]
|
372 |
)
|
373 |
|
374 |
+
def enable_zip_btn(file):
    """Return an update that makes a button interactive only when a file is set.

    Args:
        file: Value to check; ``None`` means nothing has been uploaded.
            Presumably the zip upload component's value — confirm against the
            event wiring.

    Returns:
        A Gradio update setting ``interactive`` to ``file is not None``.
    """
    # gr.update works on Gradio 3.x and 4.x; the per-component
    # gr.Button.update classmethod was removed in Gradio 4.
    return gr.update(interactive=file is not None)
|
376 |
|
377 |
+
def enable_individual_btn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|