File size: 7,140 Bytes
6364b8e
 
050a6c5
e88ec7e
 
 
 
 
06b2f35
 
e88ec7e
 
 
 
 
 
06b2f35
f88c19d
 
 
050a6c5
a870a21
e88ec7e
a870a21
e88ec7e
a870a21
 
 
 
f88c19d
 
1a642a1
f88c19d
1a642a1
 
f88c19d
1a642a1
 
 
f88c19d
1a642a1
f88c19d
 
 
1a642a1
f88c19d
1a642a1
f88c19d
1a642a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e88ec7e
 
f88c19d
e88ec7e
 
 
 
af90ec3
f88c19d
af90ec3
 
 
 
 
 
275549c
af90ec3
 
 
 
 
f88c19d
 
af90ec3
 
 
 
 
f88c19d
275549c
f88c19d
af90ec3
e88ec7e
be0d928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f88c19d
e88ec7e
 
f88c19d
 
e88ec7e
f88c19d
275549c
1a642a1
06b2f35
 
a870a21
06b2f35
 
 
a870a21
 
 
050a6c5
 
 
a870a21
 
 
e88ec7e
 
050a6c5
 
 
a870a21
050a6c5
a870a21
050a6c5
 
 
5ca31bc
e88ec7e
5ca31bc
e88ec7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f88c19d
 
 
 
e88ec7e
 
 
 
f88c19d
e88ec7e
 
 
 
 
 
 
f88c19d
 
 
 
 
 
 
1a642a1
e88ec7e
cf9edec
f88c19d
e88ec7e
f88c19d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import html
import io
import os
import shutil
import tempfile
import urllib.request
import uuid
import zipfile
from functools import lru_cache
from urllib.parse import urlparse

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from numpy import exp
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline

# Hugging Face checkpoint id used by every classifier in this app.
model = "cmckinle/sdxl-flux-detector"
# Shared image-classification pipeline, built once when the module is imported.
pipe = pipeline(task="image-classification", model=model)

# Running history of per-image {"Real": p, "AI": p} dicts appended by the
# classifiers below. NOTE(review): module-global, so it is shared by every
# concurrent user session.
fin_sum = list()
# Random identifier generated once at import time.
uid = uuid.uuid4()

# Softmax function
def softmax(vector):
    e = exp(vector - vector.max())  # for numerical stability
    return e / e.sum()

# Single image classification function
def image_classifier(image):
    labels = ["AI", "Real"]
    outputs = pipe(image)
    results = {}
    for idx, result in enumerate(outputs):
        results[labels[idx]] = float(outputs[idx]['score'])
    fin_sum.append(results)
    return results

def aiornot(image):
    labels = ["AI", "Real"]
    feature_extractor = AutoFeatureExtractor.from_pretrained(model)
    model_cls = AutoModelForImageClassification.from_pretrained(model)
    input = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model_cls(**input)
        logits = outputs.logits
        probability = softmax(logits)
        px = pd.DataFrame(probability.numpy())
    prediction = logits.argmax(-1).item()
    label = labels[prediction]

    html_out = f"""
    <h1>This image is likely: {label}</h1><br><h3>
    Probabilities:<br>
    Real: {float(px[1][0]):.4f}<br>
    AI: {float(px[0][0]):.4f}"""
    
    results = {
        "Real": float(px[1][0]),
        "AI": float(px[0][0])
    }
    fin_sum.append(results)
    return gr.HTML.update(html_out), results

# Function to extract images from zip
def extract_zip(zip_file):
    temp_dir = tempfile.mkdtemp()
    with zipfile.ZipFile(zip_file, 'r') as z:
        z.extractall(temp_dir)
    return temp_dir

# Function to classify images in a folder
def classify_images(image_dir):
    images = []
    labels = []
    preds = []
    for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
        folder_path = os.path.join(image_dir, folder_name)
        if not os.path.exists(folder_path):
            print(f"Folder not found: {folder_path}")
            continue
        for img_name in os.listdir(folder_path):
            img_path = os.path.join(folder_path, img_name)
            try:
                img = Image.open(img_path).convert("RGB")
                pred = pipe(img)
                pred_label = 0 if pred[0]['label'] == 'AI' else 1

                preds.append(pred_label)
                labels.append(ground_truth_label)
                images.append(img_name)
            except Exception as e:
                print(f"Error processing image {img_name}: {e}")
    
    print(f"Processed {len(images)} images")
    return labels, preds, images

# Function to generate evaluation metrics
def evaluate_model(labels, preds):
    cm = confusion_matrix(labels, preds)
    accuracy = accuracy_score(labels, preds)
    roc_score = roc_auc_score(labels, preds)
    report = classification_report(labels, preds)
    fpr, tpr, _ = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots()
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["AI", "Real"])
    disp.plot(cmap=plt.cm.Blues, ax=ax)
    plt.close(fig)

    fig_roc, ax_roc = plt.subplots()
    ax_roc.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
    ax_roc.set_xlim([0.0, 1.0])
    ax_roc.set_ylim([0.0, 1.05])
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax_roc.legend(loc="lower right")
    plt.close(fig_roc)

    return accuracy, roc_score, report, fig, fig_roc

# Batch processing
def process_zip(zip_file):
    extracted_dir = extract_zip(zip_file.name)
    labels, preds, images = classify_images(extracted_dir)
    accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
    shutil.rmtree(extracted_dir)  # Clean up extracted files
    return accuracy, roc_score, report, cm_fig, roc_fig

# Single image section
def load_url(url):
    try:
        urllib.request.urlretrieve(f'{url}', f"{uid}tmp_im.png")
        image = Image.open(f"{uid}tmp_im.png")
        mes = "Image Loaded"
    except Exception as e:
        image = None
        mes = f"Image not Found<br>Error: {e}"
    return image, mes

def tot_prob():
    try:
        fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
        fin_sub = 1 - fin_out
        out = {
            "Real": f"{fin_out:.4f}",
            "AI": f"{fin_sub:.4f}"
        }
        return out
    except Exception as e:
        print(e)
        return None

def fin_clear():
    fin_sum.clear()
    return None

# Set up Gradio app
with gr.Blocks() as app:
    gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")

    with gr.Tabs():
        # Tab for single image detection
        with gr.Tab("Single Image Detection"):
            with gr.Column():
                inp = gr.Image(type='pil')
                in_url = gr.Textbox(label="Image URL")
                with gr.Row():
                    load_btn = gr.Button("Load URL")
                    btn = gr.Button("Detect AI")
                mes = gr.HTML("""""")

            with gr.Group():
                with gr.Row():
                    fin = gr.Label(label="Final Probability")
                with gr.Row():
                    with gr.Box():
                        gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{model}'>{model}</a></b>""")
                        outp = gr.HTML("""""")
                        n_out = gr.Label(label="Output")

            btn.click(fin_clear, None, fin, show_progress=False)
            load_btn.click(load_url, in_url, [inp, mes])

            btn.click(aiornot, [inp], [outp, n_out]).then(
                tot_prob, None, fin, show_progress=False)

        # Tab for batch processing
        with gr.Tab("Batch Image Processing"):
            zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
            batch_btn = gr.Button("Process Batch")

            with gr.Group():
                gr.Markdown(f"### Results for {model}")
                output_acc = gr.Label(label="Accuracy")
                output_roc = gr.Label(label="ROC Score")
                output_report = gr.Textbox(label="Classification Report", lines=10)
                output_cm = gr.Plot(label="Confusion Matrix")
                output_roc_plot = gr.Plot(label="ROC Curve")

    # Connect batch processing
    batch_btn.click(process_zip, zip_file, 
                    [output_acc, output_roc, output_report, output_cm, output_roc_plot])

app.launch(show_api=False, max_threads=24)