cmckinle commited on
Commit
773268b
·
verified ·
1 Parent(s): f88c19d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -165
app.py CHANGED
@@ -5,78 +5,69 @@ import os
5
  import zipfile
6
  import shutil
7
  import matplotlib.pyplot as plt
8
- from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc
9
  from PIL import Image
10
  import uuid
11
  import tempfile
12
  import pandas as pd
13
- from numpy import exp
14
  import numpy as np
15
- from sklearn.metrics import ConfusionMatrixDisplay
16
  import urllib.request
17
 
18
- # Define model
19
- model = "cmckinle/sdxl-flux-detector"
20
- pipe = pipeline("image-classification", model)
21
-
22
- fin_sum = []
23
- uid = uuid.uuid4()
24
-
25
- # Softmax function
26
- def softmax(vector):
27
- e = exp(vector - vector.max()) # for numerical stability
28
- return e / e.sum()
29
-
30
- # Single image classification function
31
- def image_classifier(image):
32
- labels = ["AI", "Real"]
33
- outputs = pipe(image)
34
- results = {}
35
- for idx, result in enumerate(outputs):
36
- results[labels[idx]] = float(outputs[idx]['score'])
37
- fin_sum.append(results)
38
- return results
39
-
40
- def aiornot(image):
41
- labels = ["AI", "Real"]
42
- feature_extractor = AutoFeatureExtractor.from_pretrained(model)
43
- model_cls = AutoModelForImageClassification.from_pretrained(model)
44
- input = feature_extractor(image, return_tensors="pt")
45
- with torch.no_grad():
46
- outputs = model_cls(**input)
47
- logits = outputs.logits
48
- probability = softmax(logits)
49
- px = pd.DataFrame(probability.numpy())
50
- prediction = logits.argmax(-1).item()
51
- label = labels[prediction]
52
-
53
- html_out = f"""
54
- <h1>This image is likely: {label}</h1><br><h3>
55
- Probabilities:<br>
56
- Real: {float(px[1][0]):.4f}<br>
57
- AI: {float(px[0][0]):.4f}"""
58
-
59
- results = {
60
- "Real": float(px[1][0]),
61
- "AI": float(px[0][0])
62
- }
63
- fin_sum.append(results)
64
- return gr.HTML.update(html_out), results
65
-
66
- # Function to extract images from zip
67
- def extract_zip(zip_file):
68
  temp_dir = tempfile.mkdtemp()
69
- with zipfile.ZipFile(zip_file, 'r') as z:
70
  z.extractall(temp_dir)
71
- return temp_dir
72
-
73
- # Function to classify images in a folder
74
- def classify_images(image_dir):
75
- images = []
76
- labels = []
77
- preds = []
78
  for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
79
- folder_path = os.path.join(image_dir, folder_name)
80
  if not os.path.exists(folder_path):
81
  print(f"Folder not found: {folder_path}")
82
  continue
@@ -84,19 +75,18 @@ def classify_images(image_dir):
84
  img_path = os.path.join(folder_path, img_name)
85
  try:
86
  img = Image.open(img_path).convert("RGB")
87
- pred = pipe(img)
88
- pred_label = 0 if pred[0]['label'] == 'AI' else 1
89
-
90
  preds.append(pred_label)
91
  labels.append(ground_truth_label)
92
  images.append(img_name)
93
  except Exception as e:
94
  print(f"Error processing image {img_name}: {e}")
95
 
96
- print(f"Processed {len(images)} images")
97
- return labels, preds, images
98
 
99
- # Function to generate evaluation metrics
100
  def evaluate_model(labels, preds):
101
  cm = confusion_matrix(labels, preds)
102
  accuracy = accuracy_score(labels, preds)
@@ -105,105 +95,91 @@ def evaluate_model(labels, preds):
105
  fpr, tpr, _ = roc_curve(labels, preds)
106
  roc_auc = auc(fpr, tpr)
107
 
108
- fig, ax = plt.subplots()
109
- disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["AI", "Real"])
110
- disp.plot(cmap=plt.cm.Blues, ax=ax)
111
- plt.close(fig)
112
-
113
- fig_roc, ax_roc = plt.subplots()
114
- ax_roc.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
115
- ax_roc.plot([0, 1], [0, 1], color='gray', linestyle='--')
116
- ax_roc.set_xlim([0.0, 1.0])
117
- ax_roc.set_ylim([0.0, 1.05])
118
- ax_roc.set_xlabel('False Positive Rate')
119
- ax_roc.set_ylabel('True Positive Rate')
120
- ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
121
- ax_roc.legend(loc="lower right")
122
- plt.close(fig_roc)
123
-
124
- return accuracy, roc_score, report, fig, fig_roc
125
-
126
- # Batch processing
127
- def process_zip(zip_file):
128
- extracted_dir = extract_zip(zip_file.name)
129
- labels, preds, images = classify_images(extracted_dir)
130
- accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
131
- shutil.rmtree(extracted_dir) # Clean up extracted files
132
- return accuracy, roc_score, report, cm_fig, roc_fig
133
 
134
- # Single image section
135
  def load_url(url):
136
  try:
137
- urllib.request.urlretrieve(f'{url}', f"{uid}tmp_im.png")
138
- image = Image.open(f"{uid}tmp_im.png")
139
- mes = "Image Loaded"
140
  except Exception as e:
141
  image = None
142
- mes = f"Image not Found<br>Error: {e}"
143
- return image, mes
144
-
145
- def tot_prob():
146
- try:
147
- fin_out = sum([result["Real"] for result in fin_sum]) / len(fin_sum)
148
- fin_sub = 1 - fin_out
149
- out = {
150
- "Real": f"{fin_out:.4f}",
151
- "AI": f"{fin_sub:.4f}"
152
- }
153
- return out
154
- except Exception as e:
155
- print(e)
156
- return None
157
-
158
- def fin_clear():
159
- fin_sum.clear()
160
- return None
161
-
162
- # Set up Gradio app
163
- with gr.Blocks() as app:
164
- gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
165
-
166
- with gr.Tabs():
167
- # Tab for single image detection
168
- with gr.Tab("Single Image Detection"):
169
- with gr.Column():
170
- inp = gr.Image(type='pil')
171
- in_url = gr.Textbox(label="Image URL")
172
- with gr.Row():
173
- load_btn = gr.Button("Load URL")
174
- btn = gr.Button("Detect AI")
175
- mes = gr.HTML("""""")
176
-
177
- with gr.Group():
178
- with gr.Row():
179
- fin = gr.Label(label="Final Probability")
180
- with gr.Row():
181
- with gr.Box():
182
- gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{model}'>{model}</a></b>""")
183
- outp = gr.HTML("""""")
184
- n_out = gr.Label(label="Output")
185
-
186
- btn.click(fin_clear, None, fin, show_progress=False)
187
- load_btn.click(load_url, in_url, [inp, mes])
188
-
189
- btn.click(aiornot, [inp], [outp, n_out]).then(
190
- tot_prob, None, fin, show_progress=False)
191
-
192
- # Tab for batch processing
193
- with gr.Tab("Batch Image Processing"):
194
- zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
195
- batch_btn = gr.Button("Process Batch")
196
-
197
- with gr.Group():
198
- gr.Markdown(f"### Results for {model}")
199
- output_acc = gr.Label(label="Accuracy")
200
- output_roc = gr.Label(label="ROC Score")
201
- output_report = gr.Textbox(label="Classification Report", lines=10)
202
- output_cm = gr.Plot(label="Confusion Matrix")
203
- output_roc_plot = gr.Plot(label="ROC Curve")
204
-
205
- # Connect batch processing
206
- batch_btn.click(process_zip, zip_file,
207
- [output_acc, output_roc, output_report, output_cm, output_roc_plot])
208
-
209
- app.launch(show_api=False, max_threads=24)
 
5
  import zipfile
6
  import shutil
7
  import matplotlib.pyplot as plt
8
+ from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, roc_curve, auc, ConfusionMatrixDisplay
9
  from PIL import Image
10
  import uuid
11
  import tempfile
12
  import pandas as pd
 
13
  import numpy as np
 
14
  import urllib.request
15
 
16
+ MODEL_NAME = "cmckinle/sdxl-flux-detector"
17
+ LABELS = ["AI", "Real"]
18
+
19
+ class AIDetector:
20
+ def __init__(self):
21
+ self.pipe = pipeline("image-classification", MODEL_NAME)
22
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
23
+ self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
24
+ self.results = []
25
+
26
+ @staticmethod
27
+ def softmax(vector):
28
+ e = np.exp(vector - np.max(vector))
29
+ return e / e.sum()
30
+
31
+ def classify_image(self, image):
32
+ outputs = self.pipe(image)
33
+ results = {label: float(output['score']) for label, output in zip(LABELS, outputs)}
34
+ self.results.append(results)
35
+ return results
36
+
37
+ def predict(self, image):
38
+ inputs = self.feature_extractor(image, return_tensors="pt")
39
+ with torch.no_grad():
40
+ outputs = self.model(**inputs)
41
+ logits = outputs.logits
42
+ probabilities = self.softmax(logits.numpy())
43
+
44
+ prediction = logits.argmax(-1).item()
45
+ label = LABELS[prediction]
46
+
47
+ results = {label: float(prob) for label, prob in zip(LABELS, probabilities[0])}
48
+ self.results.append(results)
49
+
50
+ return label, results
51
+
52
+ def get_total_probability(self):
53
+ if not self.results:
54
+ return None
55
+ avg_real_prob = sum(result["Real"] for result in self.results) / len(self.results)
56
+ return {"Real": f"{avg_real_prob:.4f}", "AI": f"{1 - avg_real_prob:.4f}"}
57
+
58
+ def clear_results(self):
59
+ self.results.clear()
60
+
61
+ def process_zip(zip_file):
 
 
 
 
62
  temp_dir = tempfile.mkdtemp()
63
+ with zipfile.ZipFile(zip_file.name, 'r') as z:
64
  z.extractall(temp_dir)
65
+
66
+ labels, preds, images = [], [], []
67
+ detector = AIDetector()
68
+
 
 
 
69
  for folder_name, ground_truth_label in [('real', 1), ('ai', 0)]:
70
+ folder_path = os.path.join(temp_dir, folder_name)
71
  if not os.path.exists(folder_path):
72
  print(f"Folder not found: {folder_path}")
73
  continue
 
75
  img_path = os.path.join(folder_path, img_name)
76
  try:
77
  img = Image.open(img_path).convert("RGB")
78
+ _, prediction = detector.predict(img)
79
+ pred_label = 0 if prediction["AI"] > prediction["Real"] else 1
80
+
81
  preds.append(pred_label)
82
  labels.append(ground_truth_label)
83
  images.append(img_name)
84
  except Exception as e:
85
  print(f"Error processing image {img_name}: {e}")
86
 
87
+ shutil.rmtree(temp_dir)
88
+ return evaluate_model(labels, preds)
89
 
 
90
  def evaluate_model(labels, preds):
91
  cm = confusion_matrix(labels, preds)
92
  accuracy = accuracy_score(labels, preds)
 
95
  fpr, tpr, _ = roc_curve(labels, preds)
96
  roc_auc = auc(fpr, tpr)
97
 
98
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
99
+
100
+ ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS).plot(cmap=plt.cm.Blues, ax=ax1)
101
+ ax1.set_title("Confusion Matrix")
102
+
103
+ ax2.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
104
+ ax2.plot([0, 1], [0, 1], color='gray', linestyle='--')
105
+ ax2.set_xlim([0.0, 1.0])
106
+ ax2.set_ylim([0.0, 1.05])
107
+ ax2.set_xlabel('False Positive Rate')
108
+ ax2.set_ylabel('True Positive Rate')
109
+ ax2.set_title('ROC Curve')
110
+ ax2.legend(loc="lower right")
111
+
112
+ plt.tight_layout()
113
+
114
+ return accuracy, roc_score, report, fig
 
 
 
 
 
 
 
 
115
 
 
116
  def load_url(url):
117
  try:
118
+ urllib.request.urlretrieve(url, "temp_image.png")
119
+ image = Image.open("temp_image.png")
120
+ message = "Image Loaded"
121
  except Exception as e:
122
  image = None
123
+ message = f"Image not Found<br>Error: {e}"
124
+ return image, message
125
+
126
+ detector = AIDetector()
127
+
128
+ def create_gradio_interface():
129
+ with gr.Blocks() as app:
130
+ gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
131
+
132
+ with gr.Tabs():
133
+ with gr.Tab("Single Image Detection"):
134
+ with gr.Column():
135
+ inp = gr.Image(type='pil')
136
+ in_url = gr.Textbox(label="Image URL")
137
+ with gr.Row():
138
+ load_btn = gr.Button("Load URL")
139
+ btn = gr.Button("Detect AI")
140
+ message = gr.HTML()
141
+
142
+ with gr.Group():
143
+ with gr.Row():
144
+ final_prob = gr.Label(label="Final Probability")
145
+ with gr.Row():
146
+ with gr.Box():
147
+ gr.HTML(f"""<b>Testing on Model: <a href='https://huggingface.co/{MODEL_NAME}'>{MODEL_NAME}</a></b>""")
148
+ output_html = gr.HTML()
149
+ output_label = gr.Label(label="Output")
150
+
151
+ with gr.Tab("Batch Image Processing"):
152
+ zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
153
+ batch_btn = gr.Button("Process Batch")
154
+
155
+ with gr.Group():
156
+ gr.Markdown(f"### Results for {MODEL_NAME}")
157
+ output_acc = gr.Label(label="Accuracy")
158
+ output_roc = gr.Label(label="ROC Score")
159
+ output_report = gr.Textbox(label="Classification Report", lines=10)
160
+ output_plots = gr.Plot(label="Confusion Matrix and ROC Curve")
161
+
162
+ load_btn.click(load_url, in_url, [inp, message])
163
+ btn.click(detector.clear_results, None, final_prob, show_progress=False)
164
+ btn.click(
165
+ lambda img: detector.predict(img),
166
+ inp,
167
+ [output_html, output_label]
168
+ ).then(
169
+ lambda: detector.get_total_probability(),
170
+ None,
171
+ final_prob,
172
+ show_progress=False
173
+ )
174
+
175
+ batch_btn.click(
176
+ process_zip,
177
+ zip_file,
178
+ [output_acc, output_roc, output_report, output_plots]
179
+ )
180
+
181
+ return app
182
+
183
+ if __name__ == "__main__":
184
+ app = create_gradio_interface()
185
+ app.launch(show_api=False, max_threads=24)