LPX55 commited on
Commit
71cd7c0
·
1 Parent(s): c18d29a

feat: add simple prediction model with preprocessing and postprocessing functions

Browse files
Files changed (1) hide show
  1. app_test.py +31 -9
app_test.py CHANGED
@@ -166,7 +166,38 @@ register_model_with_metadata(
166
  display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
167
  )
168
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
171
  """Predict using a specific model.
172
 
@@ -403,15 +434,6 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
403
 
404
  return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
405
 
406
- def simple_prediction(img):
407
- client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
408
- result = client.predict(
409
- input_image=handle_file(img),
410
- api_name="/simple_predict"
411
- )
412
- return result
413
-
414
-
415
  detection_model_eval_playground = gr.Interface(
416
  fn=ensemble_prediction,
417
  inputs=[
 
166
  display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
167
  )
168
 
169
+ def preprocess_simple_prediction(image):
170
+ # The simple_prediction function expects a PIL image (filepath is handled internally)
171
+ return image
172
+
173
+ def postprocess_simple_prediction(result, class_names):
174
+ scores = {name: 0.0 for name in class_names}
175
+ fake_prob = result.get("Fake Probability")
176
+ if fake_prob is not None:
177
+ # Assume class_names = ["AI", "REAL"]
178
+ scores["AI"] = float(fake_prob)
179
+ scores["REAL"] = 1.0 - float(fake_prob)
180
+ return scores
181
 
182
+ def simple_prediction(img):
183
+ client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
184
+ result = client.predict(
185
+ input_image=handle_file(img),
186
+ api_name="/simple_predict"
187
+ )
188
+ return result
189
+
190
+
191
+ register_model_with_metadata(
192
+ "simple_prediction",
193
+ simple_prediction,
194
+ preprocess_simple_prediction,
195
+ postprocess_simple_prediction,
196
+ ["AI", "REAL"],
197
+ display_name="Community Forensics",
198
+ contributor="Jeongsoo Park",
199
+ model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT"
200
+ )
201
  def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
202
  """Predict using a specific model.
203
 
 
434
 
435
  return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
436
 
 
 
 
 
 
 
 
 
 
437
  detection_model_eval_playground = gr.Interface(
438
  fn=ensemble_prediction,
439
  inputs=[