# ImageDetector / app.py
# Last commit by cmckinle: 91e34e4 "Update app.py" (verified), 13.2 kB
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import os
import zipfile
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay
from PIL import Image
import tempfile
import numpy as np
import urllib.request
import base64
from io import BytesIO
import logging
from tqdm import tqdm
# Logging: records from DEBUG upwards are appended to app.log in the
# current working directory.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.DEBUG,
    filename='app.log',
)

# Hugging Face Hub checkpoint used by AIDetector, and the class names in the
# order of the model's output indices (LABELS[argmax] is the predicted label).
MODEL_NAME = "cmckinle/sdxl-flux-detector"
LABELS = ["AI", "Real"]
class AIDetector:
    """Binary AI-generated vs. real image classifier for the ``MODEL_NAME`` checkpoint.

    The checkpoint is loaded with ``transformers`` (feature extractor + model).
    ``predict`` reports probabilities keyed by the names in ``LABELS``, assuming
    model output index 0 is "AI" and index 1 is "Real".
    NOTE(review): confirm that ordering against the checkpoint's ``config.id2label``.
    """

    def __init__(self):
        # NOTE(review): ``self.pipe`` is not used by ``predict`` and loads the
        # checkpoint a second time; kept because other code may still rely on it.
        self.pipe = pipeline("image-classification", MODEL_NAME)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    @staticmethod
    def softmax(vector):
        """Return a numerically stable softmax of *vector* along its last axis.

        A 1-D input yields a single distribution. For a 2-D ``(batch, classes)``
        array each row is normalised independently, rather than over the whole
        array, so batched logits do not leak probability mass between rows.
        """
        values = np.asarray(vector)
        # Subtracting the per-row max avoids overflow in exp() for large logits.
        shifted = values - np.max(values, axis=-1, keepdims=True)
        exps = np.exp(shifted)
        return exps / exps.sum(axis=-1, keepdims=True)

    def predict(self, image):
        """Classify one image.

        Args:
            image: a PIL image (or any input ``self.feature_extractor`` accepts).
                Inputs exposing a non-RGB ``mode`` are converted to RGB first.

        Returns:
            ``(label, results)``: the top-scoring name from ``LABELS`` and a dict
            mapping every name in ``LABELS`` to its softmax probability.

        Raises:
            ValueError: if *image* is None (e.g. "Detect AI" clicked with no image).
        """
        if image is None:
            raise ValueError("No image provided")
        # Normalise palette/RGBA/grayscale inputs so every caller hands the
        # processor the same 3-channel representation.
        if getattr(image, "mode", "RGB") != "RGB" and hasattr(image, "convert"):
            image = image.convert("RGB")

        inputs = self.feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits

        probabilities = self.softmax(logits.cpu().numpy())
        # Single image -> batch of one: read row 0.
        label = LABELS[logits[0].argmax(-1).item()]
        results = {name: float(prob) for name, prob in zip(LABELS, probabilities[0])}
        return label, results
def custom_upload_handler(file):
    """Re-read an already-uploaded file end to end, logging its size and progress.

    The file is already stored on disk when this runs. This function only reads
    it back in 1 MB chunks, showing a console ``tqdm`` bar, to confirm it is
    readable.

    Args:
        file: a path (``str`` / ``os.PathLike``), an object whose ``.name`` is the
            path (e.g. Gradio's temp-file wrapper), or None when the field was cleared.

    Returns:
        *file* unchanged; None passes straight through.

    Raises:
        gr.Error: if the path cannot be resolved, stat'ed or read.
    """
    if file is None:
        # Nothing uploaded (or the upload was cleared): not an error.
        return None
    try:
        path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
        logging.info("Starting upload of file: %s", path)
        file_size = os.path.getsize(path)
        logging.info("File size: %s bytes", file_size)

        chunk_size = 1024 * 1024  # 1 MB chunks
        total_chunks = -(-file_size // chunk_size)  # ceiling division
        with open(path, "rb") as f:
            for chunk_index in tqdm(range(total_chunks), desc="Uploading"):
                if not f.read(chunk_size):
                    # File shrank since getsize(); stop at EOF.
                    break
                logging.debug("Processed chunk %d of %d", chunk_index + 1, total_chunks)

        logging.info("File upload completed successfully")
        return file
    except Exception as e:
        logging.exception("Error during file upload: %s", e)
        raise gr.Error(f"Upload failed: {str(e)}") from e
def _has_required_folders(names, required=("real", "ai")):
    """Return True if every folder in *required* appears at the root of a zip listing.

    A folder counts as present if the archive has its explicit directory entry
    (``real/``) or any member beneath it (``real/img.png``); some zip tools do
    not write directory entries at all.
    """
    top_level = {name.split("/", 1)[0] for name in names if "/" in name}
    return all(folder in top_level for folder in required)


def _load_rgb(img_path):
    """Decode *img_path* into an RGB image and close the underlying file handle."""
    with Image.open(img_path) as src:
        return src.convert("RGB")


def _get_detector():
    """Return the module-level ``detector`` once it exists, else build a new AIDetector.

    Reusing the shared instance avoids loading the checkpoint again for every batch.
    """
    shared = globals().get("detector")
    return shared if isinstance(shared, AIDetector) else AIDetector()


def process_zip(zip_file):
    """Evaluate the detector on a labelled zip archive.

    The archive must contain top-level ``real`` and ``ai`` folders. Every regular
    file directly inside them is classified, with ground truth real=1 and ai=0.
    Files that cannot be opened or classified are logged and skipped.

    Args:
        zip_file: object whose ``.name`` is the path of the uploaded zip.

    Returns:
        The tuple returned by ``evaluate_model``:
        ``(accuracy, roc_score, report_html, fig, fp_fn_html)``.

    Raises:
        gr.Error: wrapping any failure (bad layout, no usable images, metric errors).
    """
    temp_dir = tempfile.mkdtemp()
    try:
        logging.info("Starting to process zip file: %s", zip_file.name)

        with zipfile.ZipFile(zip_file.name, 'r') as z:
            if not _has_required_folders(z.namelist()):
                raise ValueError("Zip file must contain 'real' and 'ai' folders")
            # zipfile strips absolute paths and '..' components during extraction.
            z.extractall(temp_dir)

        # Collect (name, path, ground truth) for regular files only, in a stable order.
        work = []
        for folder_name, ground_truth_label in (('real', 1), ('ai', 0)):
            folder_path = os.path.join(temp_dir, folder_name)
            if not os.path.isdir(folder_path):
                raise ValueError(f"Folder not found: {folder_path}")
            for img_name in sorted(os.listdir(folder_path)):
                img_path = os.path.join(folder_path, img_name)
                if os.path.isfile(img_path):
                    work.append((img_name, img_path, ground_truth_label))

        detector = _get_detector()
        labels, preds = [], []
        false_positives, false_negatives = [], []
        total = len(work)

        for index, (img_name, img_path, ground_truth_label) in enumerate(work, start=1):
            logging.debug("Processing image %d of %d: %s", index, total, img_name)
            try:
                _, prediction = detector.predict(_load_rgb(img_path))
                pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
                preds.append(pred_label)
                labels.append(ground_truth_label)

                if pred_label != ground_truth_label:
                    # Keep the raw file bytes so the report can show a thumbnail.
                    with open(img_path, "rb") as img_file:
                        img_data = base64.b64encode(img_file.read()).decode()
                    if pred_label == 1:  # AI image predicted as Real
                        false_positives.append((img_name, img_data))
                    else:  # Real image predicted as AI
                        false_negatives.append((img_name, img_data))
            except Exception as e:
                # One unreadable/non-image file should not abort the whole batch.
                logging.error("Error processing image %s: %s", img_name, e)

        if len(set(labels)) < 2:
            # ROC AUC and the 2x2 report need both classes present.
            raise ValueError(
                "Need at least one readable image in both the 'real' and 'ai' folders"
            )

        logging.info("Zip file processing completed successfully")
        return evaluate_model(labels, preds, false_positives, false_negatives)
    except Exception as e:
        logging.exception("Error processing zip file: %s", e)
        raise gr.Error(f"Error processing zip file: {str(e)}") from e
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
# Stylesheet (light/dark aware) plus the table's header row; rows are appended after it.
_REPORT_TABLE_HEAD = """
<style>
.report-table {
    border-collapse: collapse;
    width: 100%;
    font-family: Arial, sans-serif;
}
.report-table th, .report-table td {
    border: 1px solid;
    padding: 8px;
    text-align: center;
}
.report-table th {
    font-weight: bold;
}
.report-table tr:nth-child(even) {
    background-color: rgba(0, 0, 0, 0.05);
}
@media (prefers-color-scheme: dark) {
    .report-table {
        color: #e0e0e0;
        background-color: #2d2d2d;
    }
    .report-table th, .report-table td {
        border-color: #555;
    }
    .report-table th {
        background-color: #3d3d3d;
    }
    .report-table tr:nth-child(even) {
        background-color: #333;
    }
    .report-table tr:hover {
        background-color: #3a3a3a;
    }
}
@media (prefers-color-scheme: light) {
    .report-table {
        color: #333333;
        background-color: #ffffff;
    }
    .report-table th, .report-table td {
        border-color: #ddd;
    }
    .report-table th {
        background-color: #f2f2f2;
    }
    .report-table tr:nth-child(even) {
        background-color: #f9f9f9;
    }
    .report-table tr:hover {
        background-color: #f5f5f5;
    }
}
</style>
<table class="report-table">
<tr>
<th>Class</th>
<th>Precision</th>
<th>Recall</th>
<th>F1-Score</th>
<th>Support</th>
</tr>
"""


def format_classification_report(labels, preds):
    """Render sklearn's ``classification_report`` as a styled HTML table.

    Args:
        labels: ground-truth class labels.
        preds: predicted class labels, aligned with *labels*.

    Returns:
        An HTML string with a ``<style>`` block and a table. The table has one row
        per class in the report, then Accuracy, Macro Avg and Weighted Avg rows.
        Metrics are shown to two decimals.
    """
    from html import escape

    # zero_division=0 reports the same 0.0 that sklearn's default ("warn") does
    # for a class with no predicted samples, but without the warning.
    report_dict = classification_report(labels, preds, output_dict=True, zero_division=0)

    summary_keys = {"accuracy", "macro avg", "weighted avg", "micro avg", "samples avg"}
    # Per-class entries come first in the report; iterating them (instead of
    # hard-coding '0'/'1') avoids a KeyError for any other label set.
    class_names = [key for key in report_dict if key not in summary_keys]

    def metric_row(title, stats):
        """One <tr> with precision / recall / F1 / support for *stats*."""
        return f"""
<tr>
<td>{escape(str(title))}</td>
<td>{stats['precision']:.2f}</td>
<td>{stats['recall']:.2f}</td>
<td>{stats['f1-score']:.2f}</td>
<td>{stats['support']}</td>
</tr>
"""

    rows = [metric_row(name, report_dict[name]) for name in class_names]
    rows.append(f"""
<tr>
<td>Accuracy</td>
<td colspan="3">{report_dict['accuracy']:.2f}</td>
<td>{report_dict['macro avg']['support']}</td>
</tr>
""")
    rows.append(metric_row("Macro Avg", report_dict["macro avg"]))
    rows.append(metric_row("Weighted Avg", report_dict["weighted avg"]))
    return _REPORT_TABLE_HEAD + "".join(rows) + "</table>\n"
def _misclassified_gallery_html(heading, items):
    """Return an ``<h3>`` heading followed by a thumbnail grid for *items*.

    Args:
        heading: heading text, inserted as-is (callers pass fixed strings).
        items: iterable of ``(file_name, base64_file_bytes)`` pairs.

    File names come from the user's zip archive, so they are HTML-escaped. The
    data-URI MIME type is guessed from the file extension, defaulting to JPEG.
    """
    from html import escape
    import mimetypes

    parts = [f"<h3>{heading}</h3>", '<div class="image-grid">']
    for img_name, img_data in items:
        guessed, _ = mimetypes.guess_type(img_name)
        mime = guessed if guessed and guessed.startswith("image/") else "image/jpeg"
        safe_name = escape(img_name)
        parts.append(f'''
<div class="image-item">
<img src="data:{mime};base64,{img_data}" alt="{safe_name}">
<p>{safe_name}</p>
</div>
''')
    parts.append('</div>')
    return "".join(parts)


def evaluate_model(labels, preds, false_positives, false_negatives):
    """Compute metrics, plots and a misclassification gallery for a batch run.

    Args:
        labels: ground-truth labels (1 = real, 0 = AI).
        preds: predicted labels in the same 0/1 encoding.
        false_positives: ``(file_name, base64_bytes)`` for AI images predicted Real.
        false_negatives: ``(file_name, base64_bytes)`` for Real images predicted AI.

    Returns:
        ``(accuracy, roc_score, report_html, fig, fp_fn_html)``.

    Note:
        ROC metrics are computed from hard 0/1 predictions, not probabilities. The
        curve therefore has a single operating point, and its AUC equals balanced
        accuracy.
    """
    # Pin the label set so the matrix is always 2x2 and matches LABELS.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    accuracy = accuracy_score(labels, preds)
    roc_score = roc_auc_score(labels, preds)
    report_html = format_classification_report(labels, preds)
    fpr, tpr, _ = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
    ax1.set_title("Confusion Matrix")

    ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    ax2.legend(loc="lower right")
    fig.tight_layout()
    # Drop the figure from pyplot's registry so repeated batch runs don't
    # accumulate open figures; the Figure object can still be saved/rendered.
    plt.close(fig)

    gallery_css = """
<style>
.image-grid {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
}
.image-item {
    display: flex;
    flex-direction: column;
    align-items: center;
}
.image-item img {
    max-width: 200px;
    max-height: 200px;
}
</style>
"""
    fp_fn_html = (
        gallery_css
        + _misclassified_gallery_html(
            "False Positives (AI images classified as Real):", false_positives
        )
        + _misclassified_gallery_html(
            "False Negatives (Real images classified as AI):", false_negatives
        )
    )
    return accuracy, roc_score, report_html, fig, fp_fn_html
def load_url(url):
    """Download an image from a user-supplied http(s) URL for the single-image tab.

    The response is read into memory, so no shared temp file is needed. Only
    http/https URLs are accepted, which blocks ``file://`` and similar schemes
    from reading server-local files. The request has a timeout, and the
    download is capped in size.

    Args:
        url: the URL typed by the user (None/blank is treated as invalid).

    Returns:
        ``(image, message)``: a fully decoded PIL image, or None on failure, plus
        an HTML status message. Any error text in the message is HTML-escaped.
    """
    import html
    from urllib.parse import urlsplit

    max_bytes = 20 * 1024 * 1024  # refuse downloads larger than 20 MB
    timeout_seconds = 15
    try:
        url = (url or "").strip()
        scheme = urlsplit(url).scheme.lower()
        if scheme not in ("http", "https"):
            raise ValueError(
                f"unsupported URL scheme {scheme!r}; only http(s) URLs are supported"
            )
        with urllib.request.urlopen(url, timeout=timeout_seconds) as response:
            data = response.read(max_bytes + 1)
        if len(data) > max_bytes:
            raise ValueError(f"image is larger than {max_bytes} bytes")
        image = Image.open(BytesIO(data))
        image.load()  # decode now so errors surface here, not later in predict
        message = "Image Loaded"
    except Exception as e:
        image = None
        message = f"Image not Found<br>Error: {html.escape(str(e))}"
    return image, message
# Shared detector instance, built once at import time (loads the MODEL_NAME
# checkpoint). The "Detect AI" button handler calls detector.predict on it.
detector = AIDetector()
def create_gradio_interface():
with gr.Blocks() as app:
gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
with gr.Tabs():
with gr.Tab("Single Image Detection"):
with gr.Column():
inp = gr.Image(type='pil')
in_url = gr.Textbox(label="Image URL")
with gr.Row():
load_btn = gr.Button("Load URL")
btn = gr.Button("Detect AI")
message = gr.HTML()
with gr.Group():
with gr.Box():
gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{MODEL_NAME}'>{MODEL_NAME}</a></b>""")
output_html = gr.HTML()
output_label = gr.Label(label="Output")
with gr.Tab("Batch Image Processing"):
zip_file = gr.File(
label="Upload Zip (must contain 'real' and 'ai' folders)",
file_types=[".zip"],
file_count="single",
max_file_size=1024 * 10, # 10240 MB (10 GB)
preprocess=custom_upload_handler
)
batch_btn = gr.Button("Process Batch", interactive=False)
with gr.Group():
gr.Markdown(f"### Results for {MODEL_NAME}")
output_acc = gr.Label(label="Accuracy")
output_roc = gr.Label(label="ROC Score")
output_report = gr.HTML(label="Classification Report")
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
output_fp_fn = gr.HTML(label="False Positives and Negatives")
load_btn.click(load_url, in_url, [inp, message])
btn.click(
lambda img: detector.predict(img),
inp,
[output_html, output_label]
)
def enable_batch_btn(file):
return gr.Button.update(interactive=file is not None)