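"""Gradio app for detecting AI-generated images.

Classifies a single image (uploaded or fetched from a URL) as "AI" or "Real" with the
cmckinle/sdxl-flux-detector model, and can batch-evaluate a zip archive of labeled
images, reporting accuracy, ROC AUC, a classification report, and plots.
"""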
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import os
import zipfile
import shutil
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay
from PIL import Image
import tempfile
import numpy as np
import urllib.request

MODEL_NAME = "cmckinle/sdxl-flux-detector"
LABELS = ["AI", "Real"]

class AIDetector:
    """Wraps the image-classification model used for AI-vs-real detection."""

    def __init__(self):
        # The pipeline is loaded alongside the raw model and feature extractor,
        # but predict() below only uses the raw model.
        self.pipe = pipeline("image-classification", MODEL_NAME)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    @staticmethod
    def softmax(vector):
        # Numerically stable softmax (subtract the max before exponentiating).
        e = np.exp(vector - np.max(vector))
        return e / e.sum()

    def predict(self, image):
        """Classify a PIL image; returns (label, {label: probability})."""
        inputs = self.feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = self.softmax(logits.numpy())

        prediction = logits.argmax(-1).item()
        label = LABELS[prediction]

        # Use a separate name in the comprehension to avoid shadowing `label`.
        results = {name: float(prob) for name, prob in zip(LABELS, probabilities[0])}

        return label, results
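
# Standalone usage sketch (outside the Gradio UI); the file name "example.jpg" is
# illustrative only:
#
#     detector = AIDetector()
#     label, probs = detector.predict(Image.open("example.jpg").convert("RGB"))
#     # label is "AI" or "Real"; probs maps each label to its probability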

def process_zip(zip_file):
    """Batch-evaluate the detector on a zip archive with `real/` and `ai/` folders."""
    temp_dir = tempfile.mkdtemp()
    # gr.File (Gradio 3.x) passes a tempfile-like object; .name is its path on disk.
    with zipfile.ZipFile(zip_file.name, 'r') as z:
        z.extractall(temp_dir)

    # `images` keeps file names for reference; only labels/preds feed the metrics.
    labels, preds, images = [], [], []
    detector = AIDetector()

    # Ground truth: real/ -> 1, ai/ -> 0.
    for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
        folder_path = os.path.join(temp_dir, folder_name)
        if not os.path.exists(folder_path):
            print(f"Folder not found: {folder_path}")
            continue
        for img_name in os.listdir(folder_path):
            img_path = os.path.join(folder_path, img_name)
            try:
                img = Image.open(img_path).convert("RGB")
                _, prediction = detector.predict(img)
                pred_label = 0 if prediction["AI"] > prediction["Real"] else 1

                preds.append(pred_label)
                labels.append(ground_truth_label)
                images.append(img_name)
            except Exception as e:
                print(f"Error processing image {img_name}: {e}")

    shutil.rmtree(temp_dir)
    return evaluate_model(labels, preds)
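
# Expected archive layout for batch processing (sketch; "dataset.zip" is an
# illustrative name, the folder names are the ones process_zip looks for):
#
#     dataset.zip
#     ├── real/   # genuine images, ground-truth label 1
#     └── ai/     # AI-generated images, ground-truth label 0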

def evaluate_model(labels, preds):
    """Compute accuracy, ROC AUC, a classification report, and diagnostic plots."""
    cm = confusion_matrix(labels, preds)
    accuracy = accuracy_score(labels, preds)
    # Note: ROC AUC and the ROC curve are computed from hard 0/1 predictions rather
    # than probabilities, so the curve has a single operating point.
    roc_score = roc_auc_score(labels, preds)
    report = classification_report(labels, preds)
    fpr, tpr, _ = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)

    # Side-by-side confusion matrix and ROC curve.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
    ax1.set_title("Confusion Matrix")

    ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    ax2.legend(loc="lower right")

    plt.tight_layout()

    return accuracy, roc_score, report, fig

def load_url(url):
    """Download an image from a URL to a temporary file; return (image, status message)."""
    try:
        urllib.request.urlretrieve(url, "temp_image.png")
        image = Image.open("temp_image.png")
        message = "Image Loaded"
    except Exception as e:
        image = None
        message = f"Image not Found<br>Error: {e}"
    return image, message

# Shared detector instance used by the single-image detection callback.
detector = AIDetector()

def create_gradio_interface():
    with gr.Blocks() as app:
        gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")

        with gr.Tabs():
            with gr.Tab("Single Image Detection"):
                with gr.Column():
                    inp = gr.Image(type='pil')
                    in_url = gr.Textbox(label="Image URL")
                    with gr.Row():
                        load_btn = gr.Button("Load URL")
                        btn = gr.Button("Detect AI")
                    message = gr.HTML()

                # gr.Box is part of the Gradio 3.x API; it was removed in Gradio 4,
                # where gr.Group alone can serve the same purpose.
                with gr.Group():
                    with gr.Box():
                        gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{MODEL_NAME}'>{MODEL_NAME}</a></b>""")
                        output_html = gr.HTML()
                        output_label = gr.Label(label="Output")

            with gr.Tab("Batch Image Processing"):
                zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
                batch_btn = gr.Button("Process Batch")

                with gr.Group():
                    gr.Markdown(f"### Results for {MODEL_NAME}")
                    output_acc = gr.Label(label="Accuracy")
                    output_roc = gr.Label(label="ROC Score")
                    output_report = gr.Textbox(label="Classification Report", lines=10)
                    output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")

        load_btn.click(load_url, in_url, [inp, message])
        # predict() returns (label, {label: probability}); the tuple maps in order
        # to the HTML and Label outputs.
        btn.click(
            detector.predict,
            inp,
            [output_html, output_label]
        )

        batch_btn.click(
            process_zip,
            zip_file,
            [output_acc, output_roc, output_report, output_plots]
        )

    return app

if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch(show_api=False, max_threads=24)
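
# Running locally (sketch, assuming this file is saved as app.py): `python app.py`
# starts the Gradio server and prints a local URL. show_api=False hides the API docs
# page; max_threads bounds the number of worker threads handling requests.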