import gradio as gr import torch from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline import os import zipfile import shutil import matplotlib.pyplot as plt from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay from PIL import Image import tempfile import numpy as np import urllib.request import base64 from io import BytesIO from reportlab.lib.pagesizes import letter from reportlab.pdfgen import canvas from bs4 import BeautifulSoup #Update MODEL_NAME = "cmckinle/sdxl-flux-detector_v1.1" LABELS = ["AI", "Real"] class AIDetector: def __init__(self): self.pipe = pipeline("image-classification", MODEL_NAME) self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME) self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) @staticmethod def softmax(vector): e = np.exp(vector - np.max(vector)) return e / e.sum() def predict(self, image): inputs = self.feature_extractor(image, return_tensors="pt") with torch.no_grad(): outputs = self.model(**inputs) logits = outputs.logits probabilities = self.softmax(logits.numpy()) prediction = logits.argmax(-1).item() label = LABELS[prediction] results = {label: float(prob) for label, prob in zip(LABELS, probabilities[0])} return label, results def process_zip(zip_file): temp_dir = tempfile.mkdtemp() try: with zipfile.ZipFile(zip_file.name, 'r') as z: file_list = z.namelist() if not ('real/' in file_list and 'ai/' in file_list): raise ValueError("Zip file must contain 'real' and 'ai' folders") z.extractall(temp_dir) return evaluate_model(temp_dir) except Exception as e: raise gr.Error(f"Error processing zip file: {str(e)}") finally: shutil.rmtree(temp_dir) def process_files(ai_files, real_files): temp_dir = tempfile.mkdtemp() try: ai_folder = os.path.join(temp_dir, 'ai') os.makedirs(ai_folder) for file in ai_files: shutil.copy(file.name, os.path.join(ai_folder, os.path.basename(file.name))) real_folder = os.path.join(temp_dir, 'real') os.makedirs(real_folder) for file in real_files: shutil.copy(file.name, os.path.join(real_folder, os.path.basename(file.name))) return evaluate_model(temp_dir) except Exception as e: raise gr.Error(f"Error processing individual files: {str(e)}") finally: shutil.rmtree(temp_dir) def evaluate_model(temp_dir): labels, preds, images = [], [], [] false_positives, false_negatives = [], [] detector = AIDetector() total_images = sum(len(files) for _, _, files in os.walk(temp_dir)) processed_images = 0 for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]: folder_path = os.path.join(temp_dir, folder_name) if not os.path.exists(folder_path): raise ValueError(f"Folder not found: {folder_path}") for img_name in os.listdir(folder_path): img_path = os.path.join(folder_path, img_name) try: with Image.open(img_path).convert("RGB") as img: _, prediction = detector.predict(img) pred_label = 0 if prediction["AI"] > prediction["Real"] else 1 preds.append(pred_label) labels.append(ground_truth_label) images.append(img_name) if pred_label != ground_truth_label: with open(img_path, "rb") as img_file: img_data = base64.b64encode(img_file.read()).decode() if pred_label == 1 and ground_truth_label == 0: false_positives.append((img_name, img_data)) elif pred_label == 0 and ground_truth_label == 1: false_negatives.append((img_name, img_data)) except Exception as e: print(f"Error processing image {img_name}: {e}") processed_images += 1 gr.Progress(processed_images / total_images) return calculate_metrics(labels, preds, false_positives, false_negatives) def calculate_metrics(labels, preds, false_positives, false_negatives): cm = confusion_matrix(labels, preds) accuracy = accuracy_score(labels, preds) roc_score = roc_auc_score(labels, preds) report_html = format_classification_report(labels, preds) fpr, tpr, _ = roc_curve(labels, preds) roc_auc = auc(fpr, tpr) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1) ax1.set_title("Confusion Matrix") ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})') ax2.plot([0, 1], [0, 1], color='gray', linestyle='--') ax2.set_xlim([0.0, 1.0]) ax2.set_ylim([0.0, 1.05]) ax2.set_xlabel('False Positive Rate') ax2.set_ylabel('True Positive Rate') ax2.set_title('ROC Curve') ax2.legend(loc="lower right") plt.tight_layout() fp_fn_html = create_fp_fn_html(false_positives, false_negatives) return accuracy, roc_score, report_html, fig, fp_fn_html def format_classification_report(labels, preds): report_dict = classification_report(labels, preds, output_dict=True) html = """
Class | Precision | Recall | F1-Score | Support |
---|---|---|---|---|
{class_name} | {report_dict[class_name]['precision']:.2f} | {report_dict[class_name]['recall']:.2f} | {report_dict[class_name]['f1-score']:.2f} | {report_dict[class_name]['support']} |
Accuracy | {report_dict['accuracy']:.2f} | {report_dict['macro avg']['support']} |
{img_name}