# Page header captured with the file: "Spaces: Sleeping".
import gradio as gr | |
import torch | |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline | |
import os | |
import zipfile | |
import shutil | |
import matplotlib.pyplot as plt | |
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay | |
from PIL import Image | |
import tempfile | |
import numpy as np | |
import urllib.request | |
import base64 | |
from io import BytesIO | |
from reportlab.lib.pagesizes import letter | |
from reportlab.pdfgen import canvas | |
from bs4 import BeautifulSoup | |
# Checkpoint name handed to the transformers loaders in AIDetector.__init__.
MODEL_NAME = "cmckinle/sdxl-flux-detector_v1.1"
# Class names indexed by the predicted class id (0 -> "AI", 1 -> "Real").
LABELS = ["AI", "Real"]
class AIDetector:
    """Classify an image as one of ``LABELS`` using the ``MODEL_NAME`` checkpoint.

    ``__init__`` creates three attributes: ``pipe`` (a transformers
    image-classification pipeline), ``feature_extractor`` and ``model``.
    """

    def __init__(self):
        # NOTE(review): ``pipe`` is not used by ``predict``; it is kept so any
        # caller that reads the attribute keeps working.
        self.pipe = pipeline("image-classification", MODEL_NAME)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    @staticmethod
    def softmax(vector):
        """Return the softmax of *vector* computed along its last axis.

        Declared as a static method so that both ``self.softmax(x)`` and
        ``AIDetector.softmax(x)`` work; without it the instance call passes
        ``self`` as the data argument and fails.

        Args:
            vector: Array-like of scores with at least one dimension.

        Returns:
            numpy.ndarray of the same shape whose last axis sums to 1.
        """
        scores = np.asarray(vector)
        # Subtract the per-row maximum first so np.exp cannot overflow.
        shifted = scores - np.max(scores, axis=-1, keepdims=True)
        exps = np.exp(shifted)
        return exps / exps.sum(axis=-1, keepdims=True)

    def predict(self, image):
        """Run the model on *image* and return ``(label, probabilities)``.

        Args:
            image: Image object accepted by the feature extractor.

        Returns:
            tuple: the winning entry of ``LABELS`` and a dict mapping every
            entry of ``LABELS`` to its probability as a ``float``.
        """
        inputs = self.feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        probabilities = self.softmax(logits.numpy())
        prediction = logits.argmax(-1).item()
        label = LABELS[prediction]
        # Only the first row is read: a single image is scored per call.
        results = {name: float(prob) for name, prob in zip(LABELS, probabilities[0])}
        return label, results
def process_zip(zip_file):
    """Extract an uploaded zip archive and evaluate the model on its images.

    Args:
        zip_file: Uploaded file object whose ``name`` attribute is the path
            of the archive on disk.

    Returns:
        Whatever ``evaluate_model`` returns for the extracted directory.

    Raises:
        gr.Error: If the archive cannot be read, lacks a ``real/`` or ``ai/``
            folder, or evaluation fails. The original exception is chained.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        with zipfile.ZipFile(zip_file.name, 'r') as z:
            member_names = z.namelist()
            # An archive may list only file members (e.g. "real/a.png")
            # without separate "real/" directory entries, so test path
            # prefixes instead of looking for the directory entry itself.
            for folder in ('real', 'ai'):
                prefix = folder + '/'
                if not any(name.startswith(prefix) for name in member_names):
                    raise ValueError("Zip file must contain 'real' and 'ai' folders")
            z.extractall(temp_dir)
        return evaluate_model(temp_dir)
    except Exception as e:
        raise gr.Error(f"Error processing zip file: {str(e)}") from e
    finally:
        # Best-effort cleanup: a failure to delete scratch files must not
        # replace the result or the error raised above.
        shutil.rmtree(temp_dir, ignore_errors=True)
def process_files(ai_files, real_files):
    """Stage individually uploaded images into ai/ and real/ folders and evaluate.

    Args:
        ai_files: Iterable of uploaded file objects (each with a ``name``
            path) to be labelled as AI images.
        real_files: Iterable of uploaded file objects to be labelled as real.

    Returns:
        Whatever ``evaluate_model`` returns for the staged directory.

    Raises:
        gr.Error: If copying or evaluation fails for any reason.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        # Same order as the folders are consumed: 'ai' first, then 'real'.
        for subfolder, uploads in (('ai', ai_files), ('real', real_files)):
            destination = os.path.join(temp_dir, subfolder)
            os.makedirs(destination)
            for upload in uploads:
                target = os.path.join(destination, os.path.basename(upload.name))
                shutil.copy(upload.name, target)
        return evaluate_model(temp_dir)
    except Exception as e:
        raise gr.Error(f"Error processing individual files: {str(e)}")
    finally:
        shutil.rmtree(temp_dir)
def evaluate_model(temp_dir, progress=None):
    """Classify every file under ``temp_dir/real`` and ``temp_dir/ai``.

    Ground truth is 1 for the ``real`` folder and 0 for the ``ai`` folder.
    A file that cannot be processed is reported with ``print`` and skipped.

    Args:
        temp_dir: Directory containing the ``real`` and ``ai`` sub-folders.
        progress: Optional callable invoked with the completed fraction
            (0..1) after each file. Nothing is reported when it is ``None``.

    Returns:
        Whatever ``calculate_metrics`` returns for the collected labels,
        predictions and misclassified images.

    Raises:
        ValueError: If either sub-folder is missing.
    """
    labels, preds = [], []
    false_positives, false_negatives = [], []
    # NOTE(review): this loads a fresh model on every call; consider sharing
    # one detector instance if evaluation start-up time matters.
    detector = AIDetector()
    total_images = sum(len(files) for _, _, files in os.walk(temp_dir))
    processed_images = 0

    for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
        folder_path = os.path.join(temp_dir, folder_name)
        if not os.path.exists(folder_path):
            raise ValueError(f"Folder not found: {folder_path}")

        for img_name in os.listdir(folder_path):
            img_path = os.path.join(folder_path, img_name)
            try:
                # Two context managers: convert() returns a new image, so the
                # file-backed original has to be closed separately.
                with Image.open(img_path) as raw_img, raw_img.convert("RGB") as img:
                    _, prediction = detector.predict(img)

                pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
                preds.append(pred_label)
                labels.append(ground_truth_label)

                if pred_label != ground_truth_label:
                    with open(img_path, "rb") as img_file:
                        img_data = base64.b64encode(img_file.read()).decode()
                    if pred_label == 1 and ground_truth_label == 0:
                        false_positives.append((img_name, img_data))
                    elif pred_label == 0 and ground_truth_label == 1:
                        false_negatives.append((img_name, img_data))
            except Exception as e:
                print(f"Error processing image {img_name}: {e}")

            processed_images += 1
            if progress is not None and total_images:
                progress(processed_images / total_images)

    return calculate_metrics(labels, preds, false_positives, false_negatives)
def calculate_metrics(labels, preds, false_positives, false_negatives):
    """Compute summary metrics and build the confusion-matrix / ROC figure.

    Args:
        labels: Ground-truth class ids.
        preds: Predicted class ids, aligned with *labels*.
        false_positives: ``(name, base64_data)`` pairs forwarded to
            ``create_fp_fn_html``.
        false_negatives: Same shape as *false_positives*.

    Returns:
        tuple: ``(accuracy, roc_score, report_html, fig, fp_fn_html)``.
    """
    matrix = confusion_matrix(labels, preds)
    accuracy = accuracy_score(labels, preds)
    roc_score = roc_auc_score(labels, preds)
    report_html = format_classification_report(labels, preds)

    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, preds)
    area_under_curve = auc(false_pos_rate, true_pos_rate)

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    matrix_ax, roc_ax = axes

    # Left panel: confusion matrix.
    matrix_display = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=LABELS)
    matrix_display.plot(cmap=plt.cm.Blues, ax=matrix_ax)
    matrix_ax.set_title("Confusion Matrix")

    # Right panel: ROC curve with the chance diagonal for reference.
    roc_ax.plot(
        false_pos_rate,
        true_pos_rate,
        color='blue',
        lw=2,
        label=f'ROC curve (area = {area_under_curve:.2f})',
    )
    roc_ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('ROC Curve')
    roc_ax.legend(loc="lower right")

    plt.tight_layout()

    fp_fn_html = create_fp_fn_html(false_positives, false_negatives)
    return accuracy, roc_score, report_html, fig, fp_fn_html
def format_classification_report(labels, preds):
    """Render scikit-learn's classification report as an HTML table.

    Args:
        labels: Ground-truth class ids.
        preds: Predicted class ids, aligned with *labels*.

    Returns:
        str: A ``<table class="report-table">`` with one row for each of the
        classes '0' and '1' plus a final accuracy row.
    """
    report = classification_report(labels, preds, output_dict=True)

    # Collect the fragments and join once at the end.
    fragments = ["""
<table class="report-table">
<tr>
<th>Class</th>
<th>Precision</th>
<th>Recall</th>
<th>F1-Score</th>
<th>Support</th>
</tr>
"""]

    for class_name in ('0', '1'):
        stats = report[class_name]
        fragments.append(f"""
<tr>
<td>{class_name}</td>
<td>{stats['precision']:.2f}</td>
<td>{stats['recall']:.2f}</td>
<td>{stats['f1-score']:.2f}</td>
<td>{stats['support']}</td>
</tr>
""")

    fragments.append(f"""
<tr>
<td>Accuracy</td>
<td colspan="3">{report['accuracy']:.2f}</td>
<td>{report['macro avg']['support']}</td>
</tr>
</table>
""")
    return "".join(fragments)
def create_fp_fn_html(false_positives, false_negatives):
    """Build an HTML grid showing the misclassified images.

    Args:
        false_positives: List of ``(img_name, img_data)`` pairs, where
            ``img_data`` is base64 text embedded in a data URI.
        false_negatives: List with the same shape; rendered after the
            false positives.

    Returns:
        str: One ``<div class='image-grid'>`` element containing an
        ``image-item`` block per image. File names are HTML-escaped because
        they come from uploaded archives.
    """
    # Imported under its own name so no local variable shadows the module.
    from html import escape

    items = []
    for img_name, img_data in list(false_positives) + list(false_negatives):
        safe_name = escape(str(img_name), quote=True)
        items.append(f"""
<div class="image-item">
<img src="data:image/jpeg;base64,{img_data}" alt="{safe_name}">
<p>{safe_name}</p>
</div>
""")
    # Close the grid container so following markup is not nested inside it.
    return "<div class='image-grid'>" + "".join(items) + "</div>"
def generate_pdf(accuracy, roc_score, report_html, confusion_matrix_plot):
    """Render the evaluation results into an in-memory PDF.

    Args:
        accuracy: Number formatted with two decimals.
        roc_score: Number formatted with two decimals.
        report_html: HTML report; only its text content is written.
        confusion_matrix_plot: Object exposing ``savefig(file, format=...)``,
            e.g. a matplotlib figure.

    Returns:
        io.BytesIO: Buffer positioned at the start of the PDF bytes.
    """
    # reportlab is already a dependency of this module; the reader wraps a
    # file-like object so drawImage can consume in-memory PNG data.
    from reportlab.lib.utils import ImageReader

    top_y = 750
    bottom_margin = 50
    line_height = 20
    image_width, image_height = 400, 300

    buffer = BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)

    # Header lines.
    c.drawString(100, top_y, "Model Results")
    c.drawString(100, 730, f"Accuracy: {accuracy:.2f}")
    c.drawString(100, 710, f"ROC Score: {roc_score:.2f}")
    y_position = 690

    # Reduce the HTML report to plain text and write it line by line.
    soup = BeautifulSoup(report_html, "html.parser")
    report_text = soup.get_text()
    for line in report_text.splitlines():
        if y_position < bottom_margin:  # Start a new page when space runs out.
            c.showPage()
            y_position = top_y
        c.drawString(100, y_position, line.strip())
        y_position -= line_height

    # The image occupies [y_position - image_height, y_position]; move to a
    # new page if that would cross the bottom margin, so it neither overlaps
    # the text above nor runs off the page.
    if y_position - image_height < bottom_margin:
        c.showPage()
        y_position = top_y

    img_buffer = BytesIO()
    confusion_matrix_plot.savefig(img_buffer, format="png")
    img_buffer.seek(0)
    c.drawImage(
        ImageReader(img_buffer),
        100,
        y_position - image_height,
        width=image_width,
        height=image_height,
    )

    c.save()
    buffer.seek(0)
    return buffer
def load_url(url):
    """Download an image from an http(s) URL.

    The download is kept in memory, so concurrent requests do not share a
    file on disk. Only ``http`` and ``https`` URLs are fetched because the
    URL is typed in by the user.

    Args:
        url: Address of the image.

    Returns:
        tuple: ``(image, message)``. ``image`` is the opened PIL image, or
        ``None`` on failure; ``message`` is an HTML status string.
    """
    from html import escape
    from urllib.parse import urlsplit

    try:
        cleaned_url = str(url).strip()
        scheme = urlsplit(cleaned_url).scheme.lower()
        if scheme not in ("http", "https"):
            raise ValueError(f"Unsupported URL scheme: {scheme or 'none'}")
        # A timeout keeps a stalled server from blocking the handler forever.
        with urllib.request.urlopen(cleaned_url, timeout=15) as response:
            payload = response.read()
        image = Image.open(BytesIO(payload))
        # Decode now so a corrupt download is reported here, not later.
        image.load()
        message = "Image Loaded"
    except Exception as e:
        image = None
        # The error text can echo parts of the URL; escape it for HTML output.
        message = f"Image not Found<br>Error: {escape(str(e))}"
    return image, message
# Module-level detector used by the single-image "Detect AI" handler;
# constructing it loads the model as soon as this module is executed.
detector = AIDetector()
def create_gradio_interface():
    """Build the Gradio Blocks app with single-image and batch tabs.

    The single-image tab loads an image (upload or URL) and shows the
    detector's label and probabilities. The batch tab evaluates a zip archive
    or two sets of uploaded files, shows the metrics, and offers the results
    as a PDF download.

    Returns:
        gr.Blocks: The assembled, not yet launched, application.
    """
    # NOTE(review): the component nesting below is reconstructed from the
    # order of the ``with`` statements; confirm against the intended layout.
    with gr.Blocks() as app:
        gr.Markdown("""<center><h1>AI Image Detector</h1></center>""")

        # Raw (accuracy, roc_score, report_html, figure) of the latest batch
        # run, so the PDF export gets the original Python objects rather than
        # whatever the display components hand back.
        results_state = gr.State(None)

        with gr.Tabs():
            with gr.Tab("Single Image Detection"):
                with gr.Column():
                    inp = gr.Image(type='pil')
                    in_url = gr.Textbox(label="Image URL")
                    with gr.Row():
                        load_btn = gr.Button("Load URL")
                        btn = gr.Button("Detect AI")
                    message = gr.HTML()
                with gr.Group():
                    with gr.Box():
                        gr.HTML(
                            """<b>Model: Newhouse AI Image Detection Model v1.1</b><br>
<i>This model has been trained on flux dev, flux schnell, Stable Diffusion 1.2, SDXL, and 3.5. The changes from 1.0 are that the model has had additional fine tuning on solid background images</i>"""
                        )
                    output_html = gr.HTML()
                    output_label = gr.Label(label="Output")

            with gr.Tab("Batch Image Processing"):
                with gr.Accordion("Upload Zip File", open=False):
                    zip_file = gr.File(
                        label="Upload Zip (must contain 'real' and 'ai' folders)",
                        file_types=[".zip"],
                        file_count="single"
                    )
                    zip_process_btn = gr.Button("Process Zip", interactive=False)
                with gr.Accordion("Upload Individual Files", open=False):
                    with gr.Row():
                        ai_files = gr.File(
                            label="Upload AI Images",
                            file_types=["image"],
                            file_count="multiple"
                        )
                        real_files = gr.File(
                            label="Upload Real Images",
                            file_types=["image"],
                            file_count="multiple"
                        )
                    individual_process_btn = gr.Button("Process Individual Files", interactive=False)
                with gr.Group():
                    # Both tabs use the same v1.1 checkpoint (see MODEL_NAME).
                    gr.Markdown("### Newhouse AI Image Detection Model v1.1")
                    output_acc = gr.Label(label="Accuracy")
                    output_roc = gr.Label(label="ROC Score")
                    output_report = gr.HTML(label="Classification Report")
                    output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
                    output_fp_fn = gr.HTML(label="False Positives and Negatives")
                download_pdf_btn = gr.Button("Download Results as PDF")
                pdf_output = gr.File(label="Download PDF", visible=False)
                reset_btn = gr.Button("Reset")

        def detect_single(img):
            """Classify the loaded image, or ask for one if none is present."""
            if img is None:
                raise gr.Error("Please upload an image or load one from a URL first.")
            return detector.predict(img)

        load_btn.click(load_url, in_url, [inp, message])
        btn.click(detect_single, inp, [output_html, output_label])

        def enable_zip_btn(file):
            return gr.Button.update(interactive=file is not None)

        def enable_individual_btn(ai_files, real_files):
            return gr.Button.update(interactive=(ai_files is not None and real_files is not None))

        zip_file.upload(enable_zip_btn, zip_file, zip_process_btn)
        ai_files.upload(enable_individual_btn, [ai_files, real_files], individual_process_btn)
        real_files.upload(enable_individual_btn, [ai_files, real_files], individual_process_btn)

        batch_outputs = [
            output_acc, output_roc, output_report, output_plots, output_fp_fn,
            results_state,
        ]

        def run_zip(zip_upload):
            """Evaluate a zip upload and remember the raw results."""
            results = tuple(process_zip(zip_upload))
            # The first four items are exactly what generate_pdf consumes.
            return (*results, results[:4])

        def run_files(ai_uploads, real_uploads):
            """Evaluate individual uploads and remember the raw results."""
            results = tuple(process_files(ai_uploads, real_uploads))
            return (*results, results[:4])

        zip_process_btn.click(run_zip, zip_file, batch_outputs)
        individual_process_btn.click(run_files, [ai_files, real_files], batch_outputs)

        def on_download_pdf(results):
            """Write the stored results to a PDF file and reveal the link."""
            if not results:
                raise gr.Error("Run a batch evaluation before downloading the PDF.")
            accuracy, roc_score, report_html, confusion_matrix_plot = results
            pdf_buffer = generate_pdf(accuracy, roc_score, report_html, confusion_matrix_plot)
            # The File component serves a path on disk, so persist the buffer.
            with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as pdf_file:
                pdf_file.write(pdf_buffer.getvalue())
            return gr.File.update(value=pdf_file.name, visible=True)

        download_pdf_btn.click(
            on_download_pdf,
            inputs=results_state,
            outputs=pdf_output
        )

        def reset_interface():
            """Return one value per reset output, in the same order."""
            return [
                None, None, None,               # zip_file, ai_files, real_files
                None, None, None, None, None,   # metric / report / plot / grid
                gr.Button.update(interactive=False),
                gr.Button.update(interactive=False),
                None,                           # results_state
                gr.File.update(value=None, visible=False),
            ]

        reset_btn.click(
            reset_interface,
            inputs=None,
            outputs=[
                zip_file, ai_files, real_files,
                output_acc, output_roc, output_report, output_plots, output_fp_fn,
                zip_process_btn, individual_process_btn,
                results_state, pdf_output
            ]
        )
    return app
if __name__ == "__main__":
    # Launch settings gathered in one place, then forwarded to launch().
    launch_options = {
        "show_api": False,
        "max_threads": 24,
        "show_error": True,
    }
    app = create_gradio_interface()
    app.launch(**launch_options)