ImageDetector / app.py
cmckinle's picture
Update app.py
c1f19b9 verified
raw
history blame
12.3 kB
import base64
import html
import mimetypes
import os
import shutil
import tempfile
import urllib.request
import zipfile
from io import BytesIO
from urllib.parse import urlparse

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
# Hugging Face Hub id of the classifier loaded by AIDetector and linked in the UI.
MODEL_NAME = "cmckinle/sdxl-flux-detector"
# Class names indexed by predicted class id: AIDetector.predict returns
# LABELS[argmax(logits)], and process_zip scores ground truth as 0 = "AI",
# 1 = "Real". Assumes this order matches the model's config.id2label -- TODO
# confirm against the model card.
LABELS = ["AI", "Real"]
class AIDetector:
    """Binary AI-vs-real image classifier backed by the MODEL_NAME checkpoint.

    The feature extractor and model are loaded once in ``__init__``. The
    ``pipe`` attribute (a transformers ``image-classification`` pipeline) is
    kept for backward compatibility, but it is built lazily on first access so
    the weights are not loaded a second time when nobody uses it.
    """

    def __init__(self):
        self._pipe = None
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    @property
    def pipe(self):
        """Return the ``image-classification`` pipeline, creating it on first use."""
        if getattr(self, "_pipe", None) is None:
            self._pipe = pipeline("image-classification", MODEL_NAME)
        return self._pipe

    @pipe.setter
    def pipe(self, value):
        # Keep ``detector.pipe = ...`` assignments working as with the old attribute.
        self._pipe = value

    @staticmethod
    def softmax(vector):
        """Numerically stable softmax over the last axis.

        A 1-D input yields one distribution. A 2-D ``(batch, classes)`` input
        yields one distribution per row. Subtracting the per-row max avoids
        overflow in ``exp``.
        """
        shifted = np.asarray(vector) - np.max(vector, axis=-1, keepdims=True)
        exps = np.exp(shifted)
        return exps / exps.sum(axis=-1, keepdims=True)

    def predict(self, image):
        """Classify a single image.

        Args:
            image: an image accepted by the feature extractor (the UI passes a
                PIL image).

        Returns:
            ``(label, scores)`` where ``label`` is the argmax entry of LABELS
            and ``scores`` maps every LABELS entry to its softmax probability.

        Raises:
            ValueError: if ``image`` is None.
        """
        if image is None:
            raise ValueError("No image provided")
        inputs = self.feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Row 0 is the only row: a single image is passed in.
        probabilities = self.softmax(logits.cpu().numpy())
        predicted_index = int(logits.argmax(-1)[0])
        label = LABELS[predicted_index]
        scores = {name: float(prob) for name, prob in zip(LABELS, probabilities[0])}
        return label, scores
def _zip_has_folder(names, folder):
    """Return True if any archive member lives under the top-level ``folder/`` directory."""
    prefix = f"{folder}/"
    # Many zip tools omit explicit directory entries, so match member prefixes
    # instead of requiring an exact "folder/" entry.
    return any(name.startswith(prefix) for name in names)


def _collect_images(root):
    """List ``(file name, path, ground-truth label)`` for files in root/real (1) and root/ai (0)."""
    jobs = []
    for folder_name, ground_truth_label in (("real", 1), ("ai", 0)):
        folder_path = os.path.join(root, folder_name)
        if not os.path.exists(folder_path):
            raise ValueError(f"Folder not found: {folder_path}")
        for img_name in sorted(os.listdir(folder_path)):
            img_path = os.path.join(folder_path, img_name)
            # Skip sub-directories and hidden files such as .DS_Store.
            if os.path.isfile(img_path) and not img_name.startswith("."):
                jobs.append((img_name, img_path, ground_truth_label))
    return jobs


def process_zip(zip_file, progress=gr.Progress()):
    """Classify every image in an uploaded zip and evaluate the detector.

    The zip must contain top-level ``real/`` and ``ai/`` folders. Images in
    ``real/`` are ground truth 1 and images in ``ai/`` are ground truth 0.
    Images that fail to load or classify are reported on stdout and skipped.

    Args:
        zip_file: the Gradio upload, either a tempfile-like object with
            ``.name`` or a plain path string.
        progress: progress tracker injected by Gradio because the default is a
            ``gr.Progress`` instance.

    Returns:
        The tuple produced by ``evaluate_model``.

    Raises:
        gr.Error: wraps any failure (bad archive, missing folders, no usable
            images, metric errors) so Gradio shows it to the user.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        if zip_file is None:
            raise ValueError("no zip file was uploaded")
        zip_path = getattr(zip_file, "name", zip_file)

        with zipfile.ZipFile(zip_path, "r") as archive:
            names = archive.namelist()
            if not (_zip_has_folder(names, "real") and _zip_has_folder(names, "ai")):
                raise ValueError("Zip file must contain 'real' and 'ai' folders")
            archive.extractall(temp_dir)

        jobs = _collect_images(temp_dir)
        total_images = len(jobs)
        labels, preds = [], []
        false_positives, false_negatives = [], []

        for index, (img_name, img_path, ground_truth_label) in enumerate(jobs):
            progress(index / total_images, desc="Classifying images")
            try:
                # Open via a context manager so the source file handle is closed;
                # convert() returns an independent in-memory copy.
                with Image.open(img_path) as source:
                    img = source.convert("RGB")
                # Reuse the module-level detector instead of reloading the model per batch.
                _, prediction = detector.predict(img)
                pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
                preds.append(pred_label)
                labels.append(ground_truth_label)

                if pred_label != ground_truth_label:
                    with open(img_path, "rb") as img_file:
                        img_data = base64.b64encode(img_file.read()).decode()
                    if pred_label == 1:
                        # Ground truth AI (0) predicted Real (1).
                        false_positives.append((img_name, img_data))
                    else:
                        # Ground truth Real (1) predicted AI (0).
                        false_negatives.append((img_name, img_data))
            except Exception as e:
                print(f"Error processing image {img_name}: {e}")

        if not labels:
            raise ValueError("no images in the 'real' or 'ai' folders could be classified")
        return evaluate_model(labels, preds, false_positives, false_negatives)
    except Exception as e:
        raise gr.Error(f"Error processing zip file: {str(e)}") from e
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
_REPORT_TABLE_STYLE = """
    <style>
    .report-table {
        border-collapse: collapse;
        width: 100%;
        font-family: Arial, sans-serif;
    }
    .report-table th, .report-table td {
        border: 1px solid;
        padding: 8px;
        text-align: center;
    }
    .report-table th {
        font-weight: bold;
    }
    .report-table tr:nth-child(even) {
        background-color: rgba(0, 0, 0, 0.05);
    }
    @media (prefers-color-scheme: dark) {
        .report-table {
            color: #e0e0e0;
            background-color: #2d2d2d;
        }
        .report-table th, .report-table td {
            border-color: #555;
        }
        .report-table th {
            background-color: #3d3d3d;
        }
        .report-table tr:nth-child(even) {
            background-color: #333;
        }
        .report-table tr:hover {
            background-color: #3a3a3a;
        }
    }
    @media (prefers-color-scheme: light) {
        .report-table {
            color: #333333;
            background-color: #ffffff;
        }
        .report-table th, .report-table td {
            border-color: #ddd;
        }
        .report-table th {
            background-color: #f2f2f2;
        }
        .report-table tr:nth-child(even) {
            background-color: #f9f9f9;
        }
        .report-table tr:hover {
            background-color: #f5f5f5;
        }
    }
    </style>
    """

_REPORT_TABLE_HEADER = """
    <table class="report-table">
        <tr>
            <th>Class</th>
            <th>Precision</th>
            <th>Recall</th>
            <th>F1-Score</th>
            <th>Support</th>
        </tr>
    """


def _report_row(name, metrics):
    """Return one table row: name, precision/recall/F1 (2 dp) and support."""
    return f"""
        <tr>
            <td>{name}</td>
            <td>{metrics['precision']:.2f}</td>
            <td>{metrics['recall']:.2f}</td>
            <td>{metrics['f1-score']:.2f}</td>
            <td>{metrics['support']}</td>
        </tr>
        """


def format_classification_report(labels, preds):
    """Render scikit-learn's classification report for classes 0/1 as a styled HTML table.

    Args:
        labels: ground-truth class ids (0 or 1).
        preds: predicted class ids, same length as ``labels``.

    Returns:
        An HTML string: a <style> block followed by a table with one row per
        class, then Accuracy, Macro Avg and Weighted Avg rows.
    """
    # labels=[0, 1] keeps a row for both classes even when one is absent from
    # this batch (otherwise report_dict['0'] / ['1'] raises KeyError).
    # zero_division=0 reports undefined precision/recall as 0 without a warning.
    report_dict = classification_report(
        labels, preds, labels=[0, 1], output_dict=True, zero_division=0
    )
    # Defensive: if this scikit-learn version omits "accuracy" (it can report
    # "micro avg" instead when `labels` is given), compute it directly.
    accuracy = report_dict.get("accuracy")
    if accuracy is None:
        accuracy = accuracy_score(labels, preds)
    macro = report_dict["macro avg"]

    parts = [_REPORT_TABLE_STYLE, _REPORT_TABLE_HEADER]
    parts.extend(_report_row(name, report_dict[name]) for name in ("0", "1"))
    parts.append(f"""
        <tr>
            <td>Accuracy</td>
            <td colspan="3">{accuracy:.2f}</td>
            <td>{macro['support']}</td>
        </tr>
        """)
    parts.append(_report_row("Macro Avg", macro))
    parts.append(_report_row("Weighted Avg", report_dict["weighted avg"]))
    parts.append("\n    </table>\n    ")
    return "".join(parts)
_IMAGE_GRID_STYLE = """
    <style>
    .image-grid {
        display: flex;
        flex-wrap: wrap;
        gap: 10px;
    }
    .image-item {
        display: flex;
        flex-direction: column;
        align-items: center;
    }
    .image-item img {
        max-width: 200px;
        max-height: 200px;
    }
    </style>
    """


def _image_grid_html(title, items):
    """Render an <h3> title and a flex grid of base64 thumbnails with escaped captions.

    Args:
        title: heading text (trusted, not escaped).
        items: iterable of ``(file name, base64-encoded file bytes)`` pairs.
    """
    parts = [f"<h3>{title}</h3>", '<div class="image-grid">']
    for img_name, img_data in items:
        mime, _ = mimetypes.guess_type(img_name)
        if not mime or not mime.startswith("image/"):
            mime = "image/jpeg"  # previous hard-coded type, used as the fallback
        # File names come from the user's zip; escape them before embedding in HTML.
        safe_name = html.escape(img_name, quote=True)
        parts.append(f'''
        <div class="image-item">
            <img src="data:{mime};base64,{img_data}" alt="{safe_name}">
            <p>{safe_name}</p>
        </div>
        ''')
    parts.append("</div>")
    return "".join(parts)


def evaluate_model(labels, preds, false_positives, false_negatives):
    """Compute batch metrics and build every output shown on the batch tab.

    Args:
        labels: ground-truth class ids (0 = AI, 1 = Real, as assigned by process_zip).
        preds: predicted class ids, same order as ``labels``.
        false_positives: ``(file name, base64 bytes)`` for AI images predicted Real.
        false_negatives: ``(file name, base64 bytes)`` for Real images predicted AI.

    Returns:
        ``(accuracy, roc_score, report_html, fig, fp_fn_html)``.

    Note:
        ROC/AUC are computed from hard 0/1 predictions rather than
        probabilities. The "curve" therefore has a single operating point, and
        its AUC equals balanced accuracy. ``roc_auc_score`` raises ValueError
        if only one ground-truth class is present.
    """
    # Pin the class ids so the matrix is always 2x2 and lines up with LABELS.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    accuracy = accuracy_score(labels, preds)
    roc_score = roc_auc_score(labels, preds)
    report_html = format_classification_report(labels, preds)
    fpr, tpr, _ = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
    ax1.set_title("Confusion Matrix")

    ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    ax2.legend(loc="lower right")
    fig.tight_layout()
    # Unregister the figure from pyplot so repeated batches do not accumulate
    # open figures. The Figure object itself is still returned for rendering.
    plt.close(fig)

    fp_fn_html = "".join([
        _IMAGE_GRID_STYLE,
        _image_grid_html("False Positives (AI images classified as Real):", false_positives),
        _image_grid_html("False Negatives (Real images classified as AI):", false_negatives),
    ])
    return accuracy, roc_score, report_html, fig, fp_fn_html
def load_url(url):
    """Download an image from an http(s) URL for the single-image tab.

    The image is fetched into memory, so concurrent users no longer share a
    temp file on disk. Only ``http``/``https`` URLs are accepted, which blocks
    ``file://`` and other local schemes.

    Args:
        url: the URL typed into the textbox (may be empty or None).

    Returns:
        ``(image, message)``: a loaded PIL image and "Image Loaded" on success,
        otherwise ``None`` and an HTML error message with the error text escaped.
    """
    try:
        target = (url or "").strip()
        scheme = urlparse(target).scheme.lower()
        if scheme not in ("http", "https"):
            raise ValueError(f"unsupported URL scheme {scheme!r}; use http or https")
        with urllib.request.urlopen(target, timeout=15) as response:
            payload = response.read()
        image = Image.open(BytesIO(payload))
        image.load()  # decode now so errors surface here rather than later in the UI
        message = "Image Loaded"
    except Exception as e:
        image = None
        # The message is rendered as HTML; escape the error text, which may echo the URL.
        message = f"Image not Found<br>Error: {html.escape(str(e))}"
    return image, message
# Module-level detector used by the single-image "Detect AI" callback wired in
# create_gradio_interface. It is built at import time, so the model is loaded
# as soon as this module is imported.
detector = AIDetector()
def create_gradio_interface():
    """Build and return the Gradio Blocks app with single-image and batch tabs.

    Wiring:
      * "Load URL": load_url fills the image input and the status message.
      * "Detect AI": the shared detector classifies the current image.
      * Uploading or clearing the zip toggles the "Process Batch" button.
      * "Process Batch": process_zip fills the metrics, plots and error galleries.
    """
    with gr.Blocks() as app:
        gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")

        with gr.Tabs():
            with gr.Tab("Single Image Detection"):
                with gr.Column():
                    inp = gr.Image(type='pil')
                    in_url = gr.Textbox(label="Image URL")
                    with gr.Row():
                        load_btn = gr.Button("Load URL")
                        btn = gr.Button("Detect AI")
                    message = gr.HTML()

                with gr.Group():
                    with gr.Box():
                        gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{MODEL_NAME}'>{MODEL_NAME}</a></b>""")
                        output_html = gr.HTML()
                        output_label = gr.Label(label="Output")

            with gr.Tab("Batch Image Processing"):
                zip_file = gr.File(
                    label="Upload Zip (must contain 'real' and 'ai' folders)",
                    file_types=[".zip"],
                    file_count="single",
                    # NOTE(review): not every Gradio version accepts max_file_size on
                    # gr.File, and the unit (assumed MB here) is unverified -- confirm
                    # the limit is actually enforced by the installed version.
                    max_file_size=1024  # 1024 MB (1 GB)
                )
                batch_btn = gr.Button("Process Batch", interactive=False)

                with gr.Group():
                    gr.Markdown(f"### Results for {MODEL_NAME}")
                    output_acc = gr.Label(label="Accuracy")
                    output_roc = gr.Label(label="ROC Score")
                    output_report = gr.HTML(label="Classification Report")
                    output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
                    output_fp_fn = gr.HTML(label="False Positives and Negatives")

        load_btn.click(load_url, in_url, [inp, message])

        def detect_image(img):
            # Show a readable message instead of a traceback when no image is loaded.
            if img is None:
                raise gr.Error("Please upload an image or load one from a URL first.")
            return detector.predict(img)

        btn.click(
            detect_image,
            inp,
            [output_html, output_label]
        )

        def enable_batch_btn(file):
            return gr.Button.update(interactive=file is not None)

        # `change` (unlike `upload`) also fires when the file is cleared. This
        # disables the button again instead of letting it send None to process_zip.
        zip_file.change(
            enable_batch_btn,
            zip_file,
            batch_btn
        )

        batch_btn.click(
            process_zip,
            zip_file,
            [output_acc, output_roc, output_report, output_plots, output_fp_fn],
            api_name="batch_process"
        )

    return app
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    # show_api=False presumably hides the auto-generated API page, and
    # max_threads=24 bounds Gradio's worker threads -- confirm both against
    # the installed Gradio version's launch() documentation.
    app = create_gradio_interface()
    app.launch(show_api=False, max_threads=24)