ImageDetector / app.py
cmckinle's picture
Update app.py
cbc1123 verified
raw
history blame
12.8 kB
import base64
import functools
import html
import mimetypes
import os
import re
import shutil
import tempfile
import urllib.request
import zipfile
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import ImageReader
from reportlab.pdfgen import canvas
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
# Hugging Face Hub id of the checkpoint that AIDetector loads; also shown in the UI.
MODEL_NAME = "cmckinle/sdxl-flux-detector"
# Class names indexed by model output position: index 0 -> "AI", index 1 -> "Real".
# The batch evaluation uses the same ids (ai folder = 0, real folder = 1).
# NOTE(review): assumes this order matches the checkpoint's id2label mapping — confirm.
LABELS = ["AI", "Real"]
class AIDetector:
    """Classify images as AI-generated or real with the MODEL_NAME checkpoint."""

    def __init__(self):
        # ``predict`` only needs the feature extractor and the model.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    @functools.cached_property
    def pipe(self):
        """``transformers`` image-classification pipeline for MODEL_NAME.

        The pipeline is created on first access instead of in ``__init__``.
        Nothing in ``predict`` uses it, and building it eagerly loaded its own
        copy of the MODEL_NAME weights for every detector instance.
        """
        return pipeline("image-classification", MODEL_NAME)

    @staticmethod
    def softmax(vector):
        """Return softmax probabilities along the last axis of *vector*.

        Each row of a 2-D ``(batch, classes)`` array is normalised on its own.
        A 1-D array or a plain list is treated as a single row. The row max is
        subtracted first so that ``exp`` cannot overflow.
        """
        values = np.asarray(vector)
        exps = np.exp(values - np.max(values, axis=-1, keepdims=True))
        return exps / exps.sum(axis=-1, keepdims=True)

    def predict(self, image):
        """Classify a single image.

        Args:
            image: an image accepted by ``self.feature_extractor``. Callers in
                this app pass PIL images.

        Returns:
            ``(label, results)``. *label* is the LABELS entry of the highest
            logit. *results* maps every LABELS name to its softmax probability
            as a Python float.
        """
        inputs = self.feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probabilities = self.softmax(logits.cpu().numpy())
        label = LABELS[int(logits.argmax(-1)[0])]
        results = {name: float(prob) for name, prob in zip(LABELS, probabilities[0])}
        return label, results
def process_zip(zip_file):
    """Extract an uploaded zip of ``real/`` and ``ai/`` images and evaluate it.

    Args:
        zip_file: the Gradio upload. Either an object exposing ``.name`` (the
            temp-file wrapper) or a plain path string.

    Returns:
        The tuple produced by ``evaluate_model``.

    Raises:
        gr.Error: wraps any failure (bad archive, missing folders, evaluation
            errors). The original exception is chained as ``__cause__``.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        zip_path = getattr(zip_file, "name", zip_file)
        with zipfile.ZipFile(zip_path, 'r') as z:
            # Many zip tools store only file entries ("real/a.png") and no
            # explicit "real/" directory entry, so look at the first path component.
            top_level = {name.split("/", 1)[0] for name in z.namelist() if "/" in name}
            if not {"real", "ai"} <= top_level:
                raise ValueError("Zip file must contain 'real' and 'ai' folders")
            z.extractall(temp_dir)
        return evaluate_model(temp_dir)
    except Exception as e:
        raise gr.Error(f"Error processing zip file: {str(e)}") from e
    finally:
        # Don't let a cleanup failure replace the error being raised.
        shutil.rmtree(temp_dir, ignore_errors=True)
def _unique_destination(folder, filename):
    """Return a path in *folder* for *filename*, adding ``_N`` to avoid overwriting a file."""
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(folder, filename)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(folder, f"{stem}_{counter}{ext}")
        counter += 1
    return candidate


def process_files(ai_files, real_files):
    """Copy individually uploaded images into ``ai/`` and ``real/`` folders and evaluate them.

    Args:
        ai_files: uploads known to be AI-generated. Each item is an object
            exposing ``.name`` or a plain path string.
        real_files: uploads known to be real, in the same form.

    Returns:
        The tuple produced by ``evaluate_model``.

    Raises:
        gr.Error: wraps any failure, including a missing upload set.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        if not ai_files or not real_files:
            raise ValueError("Please upload both AI and real images")
        for folder_name, uploads in (('ai', ai_files), ('real', real_files)):
            folder = os.path.join(temp_dir, folder_name)
            os.makedirs(folder)
            for upload in uploads:
                src = getattr(upload, "name", upload)
                # Separate uploads can share a basename; keep both instead of overwriting.
                shutil.copy(src, _unique_destination(folder, os.path.basename(src)))
        return evaluate_model(temp_dir)
    except Exception as e:
        raise gr.Error(f"Error processing individual files: {str(e)}") from e
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
def evaluate_model(temp_dir, progress=None, model=None):
    """Run the detector over ``temp_dir/real`` and ``temp_dir/ai`` and score the results.

    Ground truth is 1 for files under ``real`` and 0 for files under ``ai``.
    A file that fails to open or classify is reported on stdout and skipped.

    Args:
        temp_dir: directory that must contain ``real`` and ``ai`` sub-folders.
        progress: optional callable (e.g. a ``gr.Progress`` instance). It is
            called with the fraction of files processed so far.
        model: optional object with a ``predict`` method. Defaults to the
            module-level ``detector``, so the model is not reloaded for each batch.

    Returns:
        The tuple produced by ``calculate_metrics``.

    Raises:
        ValueError: if a required folder is missing or no image could be scored.
    """
    active_detector = model if model is not None else detector

    # Build the work list first. Both folders are then validated before any
    # slow inference runs, and the progress total counts only files we score.
    jobs = []
    for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
        folder_path = os.path.join(temp_dir, folder_name)
        if not os.path.exists(folder_path):
            raise ValueError(f"Folder not found: {folder_path}")
        for img_name in sorted(os.listdir(folder_path)):
            img_path = os.path.join(folder_path, img_name)
            if os.path.isfile(img_path):
                jobs.append((img_name, img_path, ground_truth_label))

    labels, preds = [], []
    false_positives, false_negatives = [], []
    for done, (img_name, img_path, ground_truth_label) in enumerate(jobs, start=1):
        try:
            # Close the source file deterministically. convert() returns an
            # independent in-memory copy.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
            _, prediction = active_detector.predict(img)
            pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
            preds.append(pred_label)
            labels.append(ground_truth_label)
            if pred_label != ground_truth_label:
                with open(img_path, "rb") as img_file:
                    img_data = base64.b64encode(img_file.read()).decode()
                # Class 1 ("Real") is the positive class here. Predicting 1 for an
                # AI image is a false positive; predicting 0 for a real one is a
                # false negative.
                if pred_label == 1:
                    false_positives.append((img_name, img_data))
                else:
                    false_negatives.append((img_name, img_data))
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole batch.
            print(f"Error processing image {img_name}: {e}")
        if progress is not None:
            progress(done / len(jobs))

    if not labels:
        raise ValueError("No images could be processed; check that the folders contain readable images")
    return calculate_metrics(labels, preds, false_positives, false_negatives)
def calculate_metrics(labels, preds, false_positives, false_negatives):
    """Score a batch of predictions and build the values for the results panel.

    Args:
        labels: ground-truth class ids (0 = AI, 1 = Real).
        preds: predicted class ids, aligned with *labels*.
        false_positives: ``(name, base64_data)`` pairs, passed through to
            ``create_fp_fn_html``.
        false_negatives: ``(name, base64_data)`` pairs, passed through to
            ``create_fp_fn_html``.

    Returns:
        ``(accuracy, roc_score, report_html, fig, fp_fn_html)``. ``roc_score``
        is NaN when *labels* holds only one class, because ROC AUC is
        undefined in that case.
    """
    class_ids = [0, 1]  # row/column order of the matrix, matching LABELS
    # Pass labels= so the matrix is always 2x2 and matches the two display labels,
    # even if a class never occurs.
    cm = confusion_matrix(labels, preds, labels=class_ids)
    accuracy = accuracy_score(labels, preds)
    report_html = format_classification_report(labels, preds)

    # Use a standalone Figure instead of pyplot. It is not added to pyplot's global
    # figure registry, so batches don't accumulate open figures. Concurrent requests
    # also don't share pyplot's "current figure" state.
    fig = Figure(figsize=(15, 6))
    ax1, ax2 = fig.subplots(1, 2)
    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
    ax1.set_title("Confusion Matrix")

    has_both_classes = len(set(labels)) == 2
    if has_both_classes:
        roc_score = roc_auc_score(labels, preds)
        fpr, tpr, _ = roc_curve(labels, preds)
        ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_score:.2f})')
    else:
        roc_score = float("nan")
        ax2.text(0.5, 0.5, "ROC curve needs both classes", ha="center", va="center")
    ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    if has_both_classes:
        ax2.legend(loc="lower right")
    fig.tight_layout()

    fp_fn_html = create_fp_fn_html(false_positives, false_negatives)
    return accuracy, roc_score, report_html, fig, fp_fn_html
def format_classification_report(labels, preds):
    """Render sklearn's classification report for the two classes as an HTML table.

    There is one row per class id (0 and 1), shown with its LABELS name. Both
    rows appear even if a class is absent from the data. A final accuracy row
    ends with the total support.
    """
    class_ids = [0, 1]
    # labels= keeps both per-class keys present. zero_division=0 reports
    # undefined precision/recall as 0 without warnings.
    report_dict = classification_report(
        labels, preds, labels=class_ids, output_dict=True, zero_division=0
    )
    # sklearn may emit "micro avg" instead of "accuracy" when labels= is passed
    # and a class is missing from the data, so compute accuracy as a fallback.
    accuracy = report_dict.get('accuracy')
    if accuracy is None:
        accuracy = accuracy_score(labels, preds)

    rows = ["""
<table class="report-table">
<tr>
<th>Class</th>
<th>Precision</th>
<th>Recall</th>
<th>F1-Score</th>
<th>Support</th>
</tr>
"""]
    for class_id in class_ids:
        stats = report_dict[str(class_id)]
        rows.append(f"""
<tr>
<td>{LABELS[class_id]}</td>
<td>{stats['precision']:.2f}</td>
<td>{stats['recall']:.2f}</td>
<td>{stats['f1-score']:.2f}</td>
<td>{stats['support']}</td>
</tr>
""")
    rows.append(f"""
<tr>
<td>Accuracy</td>
<td colspan="3">{accuracy:.2f}</td>
<td>{report_dict['macro avg']['support']}</td>
</tr>
</table>
""")
    return "".join(rows)
def create_fp_fn_html(false_positives, false_negatives):
    """Build an HTML grid of misclassified images with inline base64 data URIs.

    Args:
        false_positives: ``(file_name, base64_data)`` pairs, rendered first.
        false_negatives: ``(file_name, base64_data)`` pairs, rendered after.

    Returns:
        A single ``<div class='image-grid'>`` element. Each item shows the
        image, its HTML-escaped file name, and whether it is a false positive
        or a false negative.
    """
    items = []
    for kind, entries in (("False positive", false_positives), ("False negative", false_negatives)):
        for img_name, img_data in entries:
            # Pick the MIME type from the extension; fall back to JPEG as before.
            guessed, _ = mimetypes.guess_type(img_name)
            mime = guessed if guessed and guessed.startswith("image/") else "image/jpeg"
            # File names come from user uploads, so escape them before embedding.
            safe_name = html.escape(img_name, quote=True)
            items.append(f"""
<div class="image-item">
<img src="data:{mime};base64,{img_data}" alt="{safe_name}">
<p>{safe_name}</p>
<p>{kind}</p>
</div>
""")
    return "<div class='image-grid'>" + "".join(items) + "</div>"
def _report_html_to_lines(report_html):
    """Flatten the classification-report HTML into plain-text rows for the PDF.

    Each ``<tr>`` becomes one line with its cell texts separated by spaces.
    ``<br>`` also breaks lines. Other tags are dropped and blank lines skipped.
    """
    flat = " ".join(report_html.split())  # newlines in HTML source carry no meaning
    lines = []
    for chunk in re.split(r"</tr>|<br\s*/?>", flat, flags=re.IGNORECASE):
        cells = re.findall(r"<t[dh][^>]*>(.*?)</t[dh]>", chunk, flags=re.IGNORECASE)
        if cells:
            text = "   ".join(re.sub(r"<[^>]+>", "", cell).strip() for cell in cells)
        else:
            text = re.sub(r"<[^>]+>", "", chunk).strip()
        if text:
            lines.append(html.unescape(text))
    return lines


def generate_pdf(accuracy, roc_score, report_html, confusion_matrix_plot):
    """Render the batch results into an in-memory letter-size PDF.

    Args:
        accuracy: number formatted with two decimals.
        roc_score: number formatted with two decimals.
        report_html: classification-report HTML, converted to plain-text rows.
        confusion_matrix_plot: a matplotlib figure (anything with ``savefig``
            and ``get_size_inches``), embedded as a PNG below the text.

    Returns:
        A ``BytesIO`` containing the PDF, rewound to position 0.
    """
    page_top, page_bottom, line_height = 750, 50, 20
    buffer = BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)
    c.drawString(100, page_top, "Model Results")
    c.drawString(100, 730, f"Accuracy: {accuracy:.2f}")
    c.drawString(100, 710, f"ROC Score: {roc_score:.2f}")

    y_position = 690
    for line in _report_html_to_lines(report_html):
        if y_position < page_bottom:
            c.showPage()
            y_position = page_top
        c.drawString(100, y_position, line)
        y_position -= line_height

    # Render the figure to PNG in memory. drawImage() accepts a filename or an
    # ImageReader but not a raw BytesIO, so wrap the buffer.
    img_buffer = BytesIO()
    confusion_matrix_plot.savefig(img_buffer, format="png")
    img_buffer.seek(0)
    fig_width_in, fig_height_in = confusion_matrix_plot.get_size_inches()
    image_width = 400
    image_height = image_width * fig_height_in / fig_width_in  # preserve the aspect ratio
    # Place the image below the text; start a new page if it would cross the bottom margin.
    if y_position - image_height < page_bottom:
        c.showPage()
        y_position = page_top
    c.drawImage(ImageReader(img_buffer), 100, y_position - image_height,
                width=image_width, height=image_height)

    c.save()
    buffer.seek(0)
    return buffer
# Module-level detector used by the single-image "Detect AI" handler.
# Constructing it loads the MODEL_NAME components at import time.
detector = AIDetector()
def create_gradio_interface():
    """Build the Gradio Blocks UI for single-image detection and batch evaluation.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    max_url_bytes = 20 * 1024 * 1024  # download cap for "Load URL"

    def load_url(url):
        """Download an image for the single-image tab.

        Returns ``(image, message_html)``. On success that is the RGB PIL image
        and an empty message. On failure it is ``None`` and an error message.
        """
        url = (url or "").strip()
        if not url.lower().startswith(("http://", "https://")):
            return None, "<p style='color:red'>Please enter an http(s) image URL.</p>"
        # NOTE(review): the server fetches an arbitrary user-supplied URL. This is
        # an SSRF risk if the host can reach internal services; consider an allow-list.
        try:
            request = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
            with urllib.request.urlopen(request, timeout=15) as response:
                data = response.read(max_url_bytes + 1)
            if len(data) > max_url_bytes:
                raise ValueError("image is larger than 20 MB")
            with Image.open(BytesIO(data)) as raw:
                image = raw.convert("RGB")
        except Exception as e:
            return None, f"<p style='color:red'>Could not load image: {html.escape(str(e))}</p>"
        return image, ""

    def detect(img):
        """Classify the current image, or show a UI error if none is loaded."""
        if img is None:
            raise gr.Error("Please upload an image or load one from a URL first.")
        return detector.predict(img)

    def enable_zip_btn(file):
        return gr.Button.update(interactive=file is not None)

    def enable_individual_btn(ai_uploads, real_uploads):
        return gr.Button.update(interactive=bool(ai_uploads) and bool(real_uploads))

    def run_zip(file):
        """Evaluate a zip upload; also return the raw results for the PDF export."""
        results = process_zip(file)
        return (*results, results)

    def run_files(ai_uploads, real_uploads):
        """Evaluate individual uploads; also return the raw results for the PDF export."""
        results = process_files(ai_uploads, real_uploads)
        return (*results, results)

    def on_download_pdf(results):
        """Write the last batch's results to a PDF file and reveal the download widget."""
        if not results:
            raise gr.Error("Process a batch of images before downloading the PDF.")
        accuracy, roc_score, report_html, fig, _ = results
        pdf_buffer = generate_pdf(accuracy, roc_score, report_html, fig)
        # gr.File serves files from disk; it cannot take a BytesIO.
        with tempfile.NamedTemporaryFile(prefix="results_", suffix=".pdf", delete=False) as pdf_file:
            pdf_file.write(pdf_buffer.getvalue())
        return gr.File.update(value=pdf_file.name, visible=True)

    def reset_interface():
        # Order must match the reset_btn outputs list below (12 entries).
        return [
            None, None, None,  # zip_file, ai_files, real_files
            None, None, None, None, None,  # batch result outputs
            gr.Button.update(interactive=False),  # zip process button
            gr.Button.update(interactive=False),  # individual process button
            None,  # stored batch results
            gr.File.update(value=None, visible=False),  # PDF download
        ]

    with gr.Blocks() as app:
        gr.Markdown("""<center><h1>AI Image Detector</h1></center>""")
        with gr.Tabs():
            with gr.Tab("Single Image Detection"):
                with gr.Column():
                    inp = gr.Image(type='pil')
                    in_url = gr.Textbox(label="Image URL")
                    with gr.Row():
                        load_btn = gr.Button("Load URL")
                        btn = gr.Button("Detect AI")
                    message = gr.HTML()
                with gr.Group():
                    with gr.Box():
                        gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{MODEL_NAME}'>{MODEL_NAME}</a></b>""")
                        output_html = gr.HTML()
                        output_label = gr.Label(label="Output")
            with gr.Tab("Batch Image Processing"):
                with gr.Accordion("Upload Zip File (max 100MB)", open=False):
                    # NOTE(review): max_file_size is not a documented gr.File argument in
                    # Gradio 3.x (Gradio 4 sets it on launch()); verify the limit is enforced.
                    zip_file = gr.File(
                        label="Upload Zip (must contain 'real' and 'ai' folders)",
                        file_types=[".zip"],
                        file_count="single",
                        max_file_size=100  # 100 MB limit
                    )
                    zip_process_btn = gr.Button("Process Zip", interactive=False)
                with gr.Accordion("Upload Individual Files (for datasets over 100MB)", open=False):
                    with gr.Row():
                        ai_files = gr.File(
                            label="Upload AI Images",
                            file_types=["image"],
                            file_count="multiple"
                        )
                        real_files = gr.File(
                            label="Upload Real Images",
                            file_types=["image"],
                            file_count="multiple"
                        )
                    individual_process_btn = gr.Button("Process Individual Files", interactive=False)
                with gr.Group():
                    gr.Markdown(f"### Results for {MODEL_NAME}")
                    output_acc = gr.Label(label="Accuracy")
                    output_roc = gr.Label(label="ROC Score")
                    output_report = gr.HTML(label="Classification Report")
                    output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
                    output_fp_fn = gr.HTML(label="False Positives and Negatives")
                download_pdf_btn = gr.Button("Download Results as PDF")
                pdf_output = gr.File(label="Download PDF", visible=False)
                reset_btn = gr.Button("Reset")
                # Raw (accuracy, roc, report_html, figure, fp_fn_html) from the last batch.
                # The display components' values are not usable for rebuilding the PDF.
                results_state = gr.State()

        batch_outputs = [output_acc, output_roc, output_report, output_plots, output_fp_fn]

        load_btn.click(load_url, in_url, [inp, message])
        btn.click(detect, inp, [output_html, output_label])

        # .change (not .upload) also fires when uploads are cleared, so the
        # buttons are disabled again instead of staying enabled.
        zip_file.change(enable_zip_btn, zip_file, zip_process_btn)
        ai_files.change(enable_individual_btn, [ai_files, real_files], individual_process_btn)
        real_files.change(enable_individual_btn, [ai_files, real_files], individual_process_btn)

        zip_process_btn.click(run_zip, zip_file, batch_outputs + [results_state])
        individual_process_btn.click(
            run_files,
            [ai_files, real_files],
            batch_outputs + [results_state]
        )
        download_pdf_btn.click(on_download_pdf, inputs=results_state, outputs=pdf_output)
        reset_btn.click(
            reset_interface,
            inputs=None,
            outputs=[
                zip_file, ai_files, real_files,
                *batch_outputs,
                zip_process_btn, individual_process_btn,
                results_state, pdf_output,
            ]
        )
    return app
if __name__ == "__main__":
    # Serve the UI: no API docs link, at most 24 worker threads, and handler
    # exceptions shown in the browser.
    launch_options = {"show_api": False, "max_threads": 24, "show_error": True}
    app = create_gradio_interface()
    app.launch(**launch_options)