Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,78 +5,69 @@ import os
|
|
5 |
import zipfile
|
6 |
import shutil
|
7 |
import matplotlib.pyplot as plt
|
8 |
-
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc
|
9 |
from PIL import Image
|
10 |
import uuid
|
11 |
import tempfile
|
12 |
import pandas as pd
|
13 |
-
from numpy import exp
|
14 |
import numpy as np
|
15 |
-
from sklearn.metrics import ConfusionMatrixDisplay
|
16 |
import urllib.request
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
results
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
return gr.HTML.update(html_out), results
|
65 |
-
|
66 |
-
# Function to extract images from zip
|
67 |
-
def extract_zip(zip_file):
|
68 |
temp_dir = tempfile.mkdtemp()
|
69 |
-
with zipfile.ZipFile(zip_file, 'r') as z:
|
70 |
z.extractall(temp_dir)
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
images = []
|
76 |
-
labels = []
|
77 |
-
preds = []
|
78 |
for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
|
79 |
-
folder_path = os.path.join(
|
80 |
if not os.path.exists(folder_path):
|
81 |
print(f"Folder not found: {folder_path}")
|
82 |
continue
|
@@ -84,19 +75,18 @@ def classify_images(image_dir):
|
|
84 |
img_path = os.path.join(folder_path, img_name)
|
85 |
try:
|
86 |
img = Image.open(img_path).convert("RGB")
|
87 |
-
|
88 |
-
pred_label = 0 if
|
89 |
-
|
90 |
preds.append(pred_label)
|
91 |
labels.append(ground_truth_label)
|
92 |
images.append(img_name)
|
93 |
except Exception as e:
|
94 |
print(f"Error processing image {img_name}: {e}")
|
95 |
|
96 |
-
|
97 |
-
return labels, preds
|
98 |
|
99 |
-
# Function to generate evaluation metrics
|
100 |
def evaluate_model(labels, preds):
|
101 |
cm = confusion_matrix(labels, preds)
|
102 |
accuracy = accuracy_score(labels, preds)
|
@@ -105,105 +95,91 @@ def evaluate_model(labels, preds):
|
|
105 |
fpr, tpr, _ = roc_curve(labels, preds)
|
106 |
roc_auc = auc(fpr, tpr)
|
107 |
|
108 |
-
fig,
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
plt.
|
123 |
-
|
124 |
-
return accuracy, roc_score, report, fig
|
125 |
-
|
126 |
-
# Batch processing
|
127 |
-
def process_zip(zip_file):
|
128 |
-
extracted_dir = extract_zip(zip_file.name)
|
129 |
-
labels, preds, images = classify_images(extracted_dir)
|
130 |
-
accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
|
131 |
-
shutil.rmtree(extracted_dir) # Clean up extracted files
|
132 |
-
return accuracy, roc_score, report, cm_fig, roc_fig
|
133 |
|
134 |
-
# Single image section
|
135 |
def load_url(url):
|
136 |
try:
|
137 |
-
urllib.request.urlretrieve(
|
138 |
-
image = Image.open(
|
139 |
-
|
140 |
except Exception as e:
|
141 |
image = None
|
142 |
-
|
143 |
-
return image,
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
# Connect batch processing
|
206 |
-
batch_btn.click(process_zip, zip_file,
|
207 |
-
[output_acc, output_roc, output_report, output_cm, output_roc_plot])
|
208 |
-
|
209 |
-
app.launch(show_api=False, max_threads=24)
|
|
|
5 |
import zipfile
|
6 |
import shutil
|
7 |
import matplotlib.pyplot as plt
|
8 |
+
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay
|
9 |
from PIL import Image
|
10 |
import uuid
|
11 |
import tempfile
|
12 |
import pandas as pd
|
|
|
13 |
import numpy as np
|
|
|
14 |
import urllib.request
|
15 |
|
16 |
+
MODEL_NAME = "cmckinle/sdxl-flux-detector"
|
17 |
+
LABELS = ["AI", "Real"]
|
18 |
+
|
19 |
+
class AIDetector:
|
20 |
+
def __init__(self):
|
21 |
+
self.pipe = pipeline("image-classification", MODEL_NAME)
|
22 |
+
self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
|
23 |
+
self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
|
24 |
+
self.results = []
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def softmax(vector):
|
28 |
+
e = np.exp(vector - np.max(vector))
|
29 |
+
return e / e.sum()
|
30 |
+
|
31 |
+
def classify_image(self, image):
|
32 |
+
outputs = self.pipe(image)
|
33 |
+
results = {label: float(output['score']) for label, output in zip(LABELS, outputs)}
|
34 |
+
self.results.append(results)
|
35 |
+
return results
|
36 |
+
|
37 |
+
def predict(self, image):
|
38 |
+
inputs = self.feature_extractor(image, return_tensors="pt")
|
39 |
+
with torch.no_grad():
|
40 |
+
outputs = self.model(**inputs)
|
41 |
+
logits = outputs.logits
|
42 |
+
probabilities = self.softmax(logits.numpy())
|
43 |
+
|
44 |
+
prediction = logits.argmax(-1).item()
|
45 |
+
label = LABELS[prediction]
|
46 |
+
|
47 |
+
results = {label: float(prob) for label, prob in zip(LABELS, probabilities[0])}
|
48 |
+
self.results.append(results)
|
49 |
+
|
50 |
+
return label, results
|
51 |
+
|
52 |
+
def get_total_probability(self):
|
53 |
+
if not self.results:
|
54 |
+
return None
|
55 |
+
avg_real_prob = sum(result["Real"] for result in self.results) / len(self.results)
|
56 |
+
return {"Real": f"{avg_real_prob:.4f}", "AI": f"{1 - avg_real_prob:.4f}"}
|
57 |
+
|
58 |
+
def clear_results(self):
|
59 |
+
self.results.clear()
|
60 |
+
|
61 |
+
def process_zip(zip_file):
|
|
|
|
|
|
|
|
|
62 |
temp_dir = tempfile.mkdtemp()
|
63 |
+
with zipfile.ZipFile(zip_file.name, 'r') as z:
|
64 |
z.extractall(temp_dir)
|
65 |
+
|
66 |
+
labels, preds, images = [], [], []
|
67 |
+
detector = AIDetector()
|
68 |
+
|
|
|
|
|
|
|
69 |
for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
|
70 |
+
folder_path = os.path.join(temp_dir, folder_name)
|
71 |
if not os.path.exists(folder_path):
|
72 |
print(f"Folder not found: {folder_path}")
|
73 |
continue
|
|
|
75 |
img_path = os.path.join(folder_path, img_name)
|
76 |
try:
|
77 |
img = Image.open(img_path).convert("RGB")
|
78 |
+
_, prediction = detector.predict(img)
|
79 |
+
pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
|
80 |
+
|
81 |
preds.append(pred_label)
|
82 |
labels.append(ground_truth_label)
|
83 |
images.append(img_name)
|
84 |
except Exception as e:
|
85 |
print(f"Error processing image {img_name}: {e}")
|
86 |
|
87 |
+
shutil.rmtree(temp_dir)
|
88 |
+
return evaluate_model(labels, preds)
|
89 |
|
|
|
90 |
def evaluate_model(labels, preds):
|
91 |
cm = confusion_matrix(labels, preds)
|
92 |
accuracy = accuracy_score(labels, preds)
|
|
|
95 |
fpr, tpr, _ = roc_curve(labels, preds)
|
96 |
roc_auc = auc(fpr, tpr)
|
97 |
|
98 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
|
99 |
+
|
100 |
+
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
|
101 |
+
ax1.set_title("Confusion Matrix")
|
102 |
+
|
103 |
+
ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
104 |
+
ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
|
105 |
+
ax2.set_xlim([0.0, 1.0])
|
106 |
+
ax2.set_ylim([0.0, 1.05])
|
107 |
+
ax2.set_xlabel('False Positive Rate')
|
108 |
+
ax2.set_ylabel('True Positive Rate')
|
109 |
+
ax2.set_title('ROC Curve')
|
110 |
+
ax2.legend(loc="lower right")
|
111 |
+
|
112 |
+
plt.tight_layout()
|
113 |
+
|
114 |
+
return accuracy, roc_score, report, fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
|
|
116 |
def load_url(url):
|
117 |
try:
|
118 |
+
urllib.request.urlretrieve(url, "temp_image.png")
|
119 |
+
image = Image.open("temp_image.png")
|
120 |
+
message = "Image Loaded"
|
121 |
except Exception as e:
|
122 |
image = None
|
123 |
+
message = f"Image not Found<br>Error: {e}"
|
124 |
+
return image, message
|
125 |
+
|
126 |
+
detector = AIDetector()
|
127 |
+
|
128 |
+
def create_gradio_interface():
|
129 |
+
with gr.Blocks() as app:
|
130 |
+
gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
|
131 |
+
|
132 |
+
with gr.Tabs():
|
133 |
+
with gr.Tab("Single Image Detection"):
|
134 |
+
with gr.Column():
|
135 |
+
inp = gr.Image(type='pil')
|
136 |
+
in_url = gr.Textbox(label="Image URL")
|
137 |
+
with gr.Row():
|
138 |
+
load_btn = gr.Button("Load URL")
|
139 |
+
btn = gr.Button("Detect AI")
|
140 |
+
message = gr.HTML()
|
141 |
+
|
142 |
+
with gr.Group():
|
143 |
+
with gr.Row():
|
144 |
+
final_prob = gr.Label(label="Final Probability")
|
145 |
+
with gr.Row():
|
146 |
+
with gr.Box():
|
147 |
+
gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{MODEL_NAME}'>{MODEL_NAME}</a></b>""")
|
148 |
+
output_html = gr.HTML()
|
149 |
+
output_label = gr.Label(label="Output")
|
150 |
+
|
151 |
+
with gr.Tab("Batch Image Processing"):
|
152 |
+
zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
|
153 |
+
batch_btn = gr.Button("Process Batch")
|
154 |
+
|
155 |
+
with gr.Group():
|
156 |
+
gr.Markdown(f"### Results for {MODEL_NAME}")
|
157 |
+
output_acc = gr.Label(label="Accuracy")
|
158 |
+
output_roc = gr.Label(label="ROC Score")
|
159 |
+
output_report = gr.Textbox(label="Classification Report", lines=10)
|
160 |
+
output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
|
161 |
+
|
162 |
+
load_btn.click(load_url, in_url, [inp, message])
|
163 |
+
btn.click(detector.clear_results, None, final_prob, show_progress=False)
|
164 |
+
btn.click(
|
165 |
+
lambda img: detector.predict(img),
|
166 |
+
inp,
|
167 |
+
[output_html, output_label]
|
168 |
+
).then(
|
169 |
+
lambda: detector.get_total_probability(),
|
170 |
+
None,
|
171 |
+
final_prob,
|
172 |
+
show_progress=False
|
173 |
+
)
|
174 |
+
|
175 |
+
batch_btn.click(
|
176 |
+
process_zip,
|
177 |
+
zip_file,
|
178 |
+
[output_acc, output_roc, output_report, output_plots]
|
179 |
+
)
|
180 |
+
|
181 |
+
return app
|
182 |
+
|
183 |
+
if __name__ == "__main__":
|
184 |
+
app = create_gradio_interface()
|
185 |
+
app.launch(show_api=False, max_threads=24)
|
|
|
|
|
|
|
|
|
|