feat: add simple prediction model with preprocessing and postprocessing functions
Browse files- app_test.py +31 -9
app_test.py
CHANGED
@@ -166,7 +166,38 @@ register_model_with_metadata(
|
|
166 |
display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
|
167 |
)
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|
171 |
"""Predict using a specific model.
|
172 |
|
@@ -403,15 +434,6 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
|
|
403 |
|
404 |
return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
|
405 |
|
406 |
-
def simple_prediction(img):
|
407 |
-
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
408 |
-
result = client.predict(
|
409 |
-
input_image=handle_file(img),
|
410 |
-
api_name="/simple_predict"
|
411 |
-
)
|
412 |
-
return result
|
413 |
-
|
414 |
-
|
415 |
detection_model_eval_playground = gr.Interface(
|
416 |
fn=ensemble_prediction,
|
417 |
inputs=[
|
|
|
166 |
display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
|
167 |
)
|
168 |
|
169 |
+
def preprocess_simple_prediction(image):
|
170 |
+
# The simple_prediction function expects a PIL image (filepath is handled internally)
|
171 |
+
return image
|
172 |
+
|
173 |
+
def postprocess_simple_prediction(result, class_names):
|
174 |
+
scores = {name: 0.0 for name in class_names}
|
175 |
+
fake_prob = result.get("Fake Probability")
|
176 |
+
if fake_prob is not None:
|
177 |
+
# Assume class_names = ["AI", "REAL"]
|
178 |
+
scores["AI"] = float(fake_prob)
|
179 |
+
scores["REAL"] = 1.0 - float(fake_prob)
|
180 |
+
return scores
|
181 |
|
182 |
+
def simple_prediction(img):
|
183 |
+
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
184 |
+
result = client.predict(
|
185 |
+
input_image=handle_file(img),
|
186 |
+
api_name="/simple_predict"
|
187 |
+
)
|
188 |
+
return result
|
189 |
+
|
190 |
+
|
191 |
+
register_model_with_metadata(
|
192 |
+
"simple_prediction",
|
193 |
+
simple_prediction,
|
194 |
+
preprocess_simple_prediction,
|
195 |
+
postprocess_simple_prediction,
|
196 |
+
["AI", "REAL"],
|
197 |
+
display_name="Community Forensics",
|
198 |
+
contributor="Jeongsoo Park",
|
199 |
+
model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT"
|
200 |
+
)
|
201 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|
202 |
"""Predict using a specific model.
|
203 |
|
|
|
434 |
|
435 |
return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
|
436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
detection_model_eval_playground = gr.Interface(
|
438 |
fn=ensemble_prediction,
|
439 |
inputs=[
|