cmckinle committed on
Commit cf9edec · verified · 1 Parent(s): a8e840f

Update app.py

Files changed (1)
  1. app.py +36 -103
app.py CHANGED
@@ -91,13 +91,32 @@ def evaluate_model(labels, preds):
 
     return accuracy, roc_score, report, fig, fig_roc
 
-# Gradio function for batch image processing
+# Gradio function for batch image processing with all models
 def process_zip(zip_file):
     extracted_dir = extract_zip(zip_file.name)
-    labels, preds, images = classify_images(extracted_dir, pipe0) # You can switch to pipe1 or pipe2
-    accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
+
+    # Run classification for each model
+    results = {}
+    for idx, pipe in enumerate([pipe0, pipe1, pipe2]):
+        labels, preds, images = classify_images(extracted_dir, pipe)
+        accuracy, roc_score, report, cm_fig, roc_fig = evaluate_model(labels, preds)
+
+        # Store results for each model
+        results[f'Model_{idx}_accuracy'] = accuracy
+        results[f'Model_{idx}_roc_score'] = roc_score
+        results[f'Model_{idx}_report'] = report
+        results[f'Model_{idx}_cm_fig'] = cm_fig
+        results[f'Model_{idx}_roc_fig'] = roc_fig
+
     shutil.rmtree(extracted_dir) # Clean up extracted files
-    return accuracy, roc_score, report, cm_fig, roc_fig
+
+    # Return results for all three models
+    return (results['Model_0_accuracy'], results['Model_0_roc_score'], results['Model_0_report'],
+            results['Model_0_cm_fig'], results['Model_0_roc_fig'],
+            results['Model_1_accuracy'], results['Model_1_roc_score'], results['Model_1_report'],
+            results['Model_1_cm_fig'], results['Model_1_roc_fig'],
+            results['Model_2_accuracy'], results['Model_2_roc_score'], results['Model_2_report'],
+            results['Model_2_cm_fig'], results['Model_2_roc_fig'])
 
 # Single image classification functions
 def image_classifier0(image):
@@ -127,87 +146,6 @@ def image_classifier2(image):
     fin_sum.append(results)
     return results
 
-def aiornot0(image):
-    labels = ["AI", "Real"]
-    mod = models[0]
-    feature_extractor0 = AutoFeatureExtractor.from_pretrained(mod)
-    model0 = AutoModelForImageClassification.from_pretrained(mod)
-    input = feature_extractor0(image, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model0(**input)
-    logits = outputs.logits
-    probability = softmax(logits) # Apply softmax on logits
-    px = pd.DataFrame(probability.numpy())
-    prediction = logits.argmax(-1).item()
-    label = labels[prediction]
-
-    html_out = f"""
-    <h1>This image is likely: {label}</h1><br><h3>
-    Probabilities:<br>
-    Real: {float(px[1][0]):.4f}<br>
-    AI: {float(px[0][0]):.4f}"""
-
-    results = {
-        "Real": float(px[1][0]),
-        "AI": float(px[0][0])
-    }
-    fin_sum.append(results)
-    return gr.HTML.update(html_out), results
-
-def aiornot1(image):
-    labels = ["AI", "Real"]
-    mod = models[1]
-    feature_extractor1 = AutoFeatureExtractor.from_pretrained(mod)
-    model1 = AutoModelForImageClassification.from_pretrained(mod)
-    input = feature_extractor1(image, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model1(**input)
-    logits = outputs.logits
-    probability = softmax(logits) # Apply softmax on logits
-    px = pd.DataFrame(probability.numpy())
-    prediction = logits.argmax(-1).item()
-    label = labels[prediction]
-
-    html_out = f"""
-    <h1>This image is likely: {label}</h1><br><h3>
-    Probabilities:<br>
-    Real: {float(px[1][0]):.4f}<br>
-    AI: {float(px[0][0]):.4f}"""
-
-    results = {
-        "Real": float(px[1][0]),
-        "AI": float(px[0][0])
-    }
-    fin_sum.append(results)
-    return gr.HTML.update(html_out), results
-
-def aiornot2(image):
-    labels = ["AI", "Real"]
-    mod = models[2]
-    feature_extractor2 = AutoFeatureExtractor.from_pretrained(mod)
-    model2 = AutoModelForImageClassification.from_pretrained(mod)
-    input = feature_extractor2(image, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model2(**input)
-    logits = outputs.logits
-    probability = softmax(logits) # Apply softmax on logits
-    px = pd.DataFrame(probability.numpy())
-    prediction = logits.argmax(-1).item()
-    label = labels[prediction]
-
-    html_out = f"""
-    <h1>This image is likely: {label}</h1><br><h3>
-    Probabilities:<br>
-    Real: {float(px[1][0]):.4f}<br>
-    AI: {float(px[0][0]):.4f}"""
-
-    results = {
-        "Real": float(px[1][0]),
-        "AI": float(px[0][0])
-    }
-    fin_sum.append(results)
-    return gr.HTML.update(html_out), results
-
 def load_url(url):
     try:
         urllib.request.urlretrieve(f'{url}', f"{uid}tmp_im.png")
@@ -235,12 +173,6 @@ def fin_clear():
     fin_sum.clear()
     return None
 
-def upd(image):
-    rand_im = uuid.uuid4()
-    image.save(f"{rand_im}-vid_tmp_proc.png")
-    out = Image.open(f"{rand_im}-vid_tmp_proc.png")
-    return out
-
 # Set up Gradio app
 with gr.Blocks() as app:
     gr.Markdown("""<center><h1>AI Image Detector<br><h4>(Test Demo - accuracy varies by model)</h4></h1></center>""")
@@ -269,11 +201,6 @@ with gr.Blocks() as app:
     btn.click(fin_clear, None, fin, show_progress=False)
    load_btn.click(load_url, in_url, [inp, mes])
 
-    btn.click(aiornot0, [inp], [outp0, n_out0]).then(
-        aiornot1, [inp], [outp1, n_out1]).then(
-        aiornot2, [inp], [outp2, n_out2]).then(
-        tot_prob, None, fin, show_progress=False)
-
     btn.click(image_classifier0, [inp], [n_out0]).then(
         image_classifier1, [inp], [n_out1]).then(
         image_classifier2, [inp], [n_out2]).then(
@@ -282,15 +209,21 @@ with gr.Blocks() as app:
     # Tab for batch processing
     with gr.Tab("Batch Image Processing"):
         zip_file = gr.File(label="Upload Zip (two folders: real, ai)")
-        output_acc = gr.Label(label="Accuracy")
-        output_roc = gr.Label(label="ROC Score")
-        output_report = gr.Textbox(label="Classification Report", lines=10)
-        output_cm = gr.Plot(label="Confusion Matrix")
-        output_roc_plot = gr.Plot(label="ROC Curve")
+
+        # Outputs for all three models
+        for i in range(3):
+            with gr.Group():
+                gr.Markdown(f"### Results for Model {i}")
+                output_acc = gr.Label(label=f"Model {i} Accuracy")
+                output_roc = gr.Label(label=f"Model {i} ROC Score")
+                output_report = gr.Textbox(label=f"Model {i} Classification Report", lines=10)
+                output_cm = gr.Plot(label=f"Model {i} Confusion Matrix")
+                output_roc_plot = gr.Plot(label=f"Model {i} ROC Curve")
 
         batch_btn = gr.Button("Process Batch")
 
         # Connect batch processing
-        batch_btn.click(process_zip, zip_file, [output_acc, output_roc, output_report, output_cm, output_roc_plot])
+        batch_btn.click(process_zip, zip_file,
+                        [output_acc, output_roc, output_report, output_cm, output_roc_plot] * 3) # For all 3 models
 
-app.launch(show_api=False, max_threads=24)
+app.launch(show_api=False, max_threads=24)
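
Note on wiring (not part of this commit): the new process_zip returns 15 values in model order 0, 1, 2 (accuracy, ROC score, report, confusion-matrix figure, ROC figure for each), and Gradio fills the outputs list positionally from that tuple. Because the layout loop above rebinds output_acc, output_roc, output_report, output_cm and output_roc_plot on every pass, the list multiplied by 3 ends up holding the Model 2 components three times. A minimal sketch of one way to keep each model's components distinct, assuming the same gr.Blocks()/gr.Tab context as app.py with gradio imported as gr:

    # Sketch only: collect the per-model output components into one ordered list
    batch_outputs = []
    for i in range(3):
        with gr.Group():
            gr.Markdown(f"### Results for Model {i}")
            batch_outputs += [
                gr.Label(label=f"Model {i} Accuracy"),
                gr.Label(label=f"Model {i} ROC Score"),
                gr.Textbox(label=f"Model {i} Classification Report", lines=10),
                gr.Plot(label=f"Model {i} Confusion Matrix"),
                gr.Plot(label=f"Model {i} ROC Curve"),
            ]

    batch_btn = gr.Button("Process Batch")

    # The 15 return values of process_zip fill these 15 components in order
    batch_btn.click(process_zip, zip_file, batch_outputs)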