major(feat): implement streaming ensemble prediction for real-time model inference and update the interface for live results
app.py CHANGED
@@ -240,23 +240,9 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
             "Label": f"Error: {str(e)}"
         }

-
-
-
-    Args:
-        img (Image.Image): The input image to classify.
-        confidence_threshold (float): The confidence threshold for classification.
-        augment_methods (list): The augmentation methods to apply to the image.
-        rotate_degrees (int): The degrees to rotate the image.
-        noise_level (int): The noise level to add to the image.
-        sharpen_strength (int): The strength of the sharpening to apply to the image.
-
-    Raises:
-        ValueError: If the input image could not be converted to a PIL Image.
-
-    Returns:
-        tuple: A tuple containing the processed image, forensic images, model predictions, raw model results, and consensus.
-    """
+# --- Streaming Ensemble Prediction ---
+def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
+    # Setup (same as before)
     if not isinstance(img, Image.Image):
         try:
             img = Image.fromarray(img)
@@ -270,35 +256,44 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
     health_agent = SystemHealthAgent()
     context_agent = ContextualIntelligenceAgent()
     anomaly_agent = ForensicAnomalyDetectionAgent()
-
     health_agent.monitor_system_health()

     if augment_methods:
         img_pil, _ = augment_image(img, augment_methods, rotate_degrees, noise_level, sharpen_strength)
     else:
         img_pil = img
-    img_np_og = np.array(img)
+    img_np_og = np.array(img)

     model_predictions_raw = {}
     confidence_scores = {}
     results = []
+    table_rows = []

+    # Stream results as each model finishes
     for model_id in MODEL_REGISTRY:
         model_start = time.time()
         result = infer(img_pil, model_id, confidence_threshold)
         model_end = time.time()
-
         monitor_agent.monitor_prediction(
             model_id,
             result["Label"],
             max(result.get("AI Score", 0.0), result.get("Real Score", 0.0)),
             model_end - model_start
         )
-
         model_predictions_raw[model_id] = result
         confidence_scores[model_id] = max(result.get("AI Score", 0.0), result.get("Real Score", 0.0))
         results.append(result)
-
+        table_rows.append([
+            result.get("Model", ""),
+            result.get("Contributor", ""),
+            round(result.get("AI Score", 0.0), 3) if result.get("AI Score") is not None else 0.0,
+            round(result.get("Real Score", 0.0), 3) if result.get("Real Score") is not None else 0.0,
+            result.get("Label", "Error")
+        ])
+        # Yield partial results: only update the table, others are None
+        yield None, None, table_rows, None, None
+
+    # After all models, compute the rest as before
     image_data_for_context = {
         "width": img.width,
         "height": img.height,
@@ -306,43 +301,29 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
     }
     detected_context_tags = context_agent.infer_context_tags(image_data_for_context, model_predictions_raw)
     logger.info(f"Detected context tags: {detected_context_tags}")
-
     adjusted_weights = weight_manager.adjust_weights(model_predictions_raw, confidence_scores, context_tags=detected_context_tags)
-
-    weighted_predictions = {
-        "AI": 0.0,
-        "REAL": 0.0,
-        "UNCERTAIN": 0.0
-    }
-
+    weighted_predictions = {"AI": 0.0, "REAL": 0.0, "UNCERTAIN": 0.0}
     for model_id, prediction in model_predictions_raw.items():
         prediction_label = prediction.get("Label")
         if prediction_label in weighted_predictions:
             weighted_predictions[prediction_label] += adjusted_weights[model_id]
         else:
             logger.warning(f"Unexpected prediction label '{prediction_label}' from model '{model_id}'. Skipping its weight in consensus.")
-
     final_prediction_label = "UNCERTAIN"
     if weighted_predictions["AI"] > weighted_predictions["REAL"] and weighted_predictions["AI"] > weighted_predictions["UNCERTAIN"]:
         final_prediction_label = "AI"
     elif weighted_predictions["REAL"] > weighted_predictions["AI"] and weighted_predictions["REAL"] > weighted_predictions["UNCERTAIN"]:
         final_prediction_label = "REAL"
-
     optimization_agent.analyze_performance(final_prediction_label, None)
-
     gradient_image = gradient_processing(img_np_og)
     gradient_image2 = gradient_processing(img_np_og, intensity=45, equalize=True)
-
     minmax_image = minmax_process(img_np_og)
     minmax_image2 = minmax_process(img_np_og, radius=6)
     bitplane_image = bit_plane_extractor(img_pil)
-
     ela1 = ELA(img_np_og, quality=75, scale=50, contrast=20, linear=False, grayscale=True)
     ela2 = ELA(img_np_og, quality=75, scale=75, contrast=25, linear=False, grayscale=True)
     ela3 = ELA(img_np_og, quality=75, scale=75, contrast=25, linear=False, grayscale=False)
-
     forensics_images = [img_pil, ela1, ela2, ela3, gradient_image, gradient_image2, minmax_image, minmax_image2, bitplane_image]
-
     forensic_output_descriptions = [
         f"Original augmented image (PIL): {img_pil.width}x{img_pil.height}",
         "ELA analysis (Pass 1): Grayscale error map, quality 75.",
@@ -356,21 +337,7 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
     ]
     anomaly_detection_results = anomaly_agent.analyze_forensic_outputs(forensic_output_descriptions)
     logger.info(f"Forensic anomaly detection: {anomaly_detection_results['summary']}")
-
-    table_rows = [[
-        r.get("Model", ""),
-        r.get("Contributor", ""),
-        round(r.get("AI Score", 0.0), 3) if r.get("AI Score") is not None else 0.0,
-        round(r.get("Real Score", 0.0), 3) if r.get("Real Score") is not None else 0.0,
-        r.get("Label", "Error")
-    ] for r in results]
-
-    logger.info(f"Type of table_rows: {type(table_rows)}")
-    for i, row in enumerate(table_rows):
-        logger.info(f"Row {i} types: {[type(item) for item in row]}")
-
     consensus_html = f"<b><span style='color:{'red' if final_prediction_label == 'AI' else ('green' if final_prediction_label == 'REAL' else 'orange')}'>{final_prediction_label}</span></b>"
-
     inference_params = {
         "confidence_threshold": confidence_threshold,
         "augment_methods": augment_methods,
@@ -379,13 +346,11 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
         "sharpen_strength": sharpen_strength,
         "detected_context_tags": detected_context_tags
     }
-
     ensemble_output_data = {
         "final_prediction_label": final_prediction_label,
         "weighted_predictions": weighted_predictions,
         "adjusted_weights": adjusted_weights
     }
-
     agent_monitoring_data_log = {
         "ensemble_monitor": {
             "alerts": monitor_agent.alerts,
@@ -403,7 +368,6 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
         },
         "forensic_anomaly_detection": anomaly_detection_results
     }
-
     log_inference_data(
         original_image=img,
         inference_params=inference_params,
@@ -413,7 +377,6 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
         agent_monitoring_data=agent_monitoring_data_log,
         human_feedback=None
     )
-
    cleaned_forensics_images = []
    for f_img in forensics_images:
        if isinstance(f_img, Image.Image):
@@ -425,22 +388,18 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
             logger.warning(f"Could not convert numpy array to PIL Image for gallery: {e}")
         else:
             logger.warning(f"Unexpected type in forensic_images: {type(f_img)}. Skipping.")
-
     logger.info(f"Cleaned forensic images types: {[type(img) for img in cleaned_forensics_images]}")
-
     for i, res_dict in enumerate(results):
         for key in ["AI Score", "Real Score"]:
             value = res_dict.get(key)
             if isinstance(value, np.float32):
                 res_dict[key] = float(value)
                 logger.info(f"Converted {key} for result {i} from numpy.float32 to float.")
-
     json_results = json.dumps(results, cls=NumpyEncoder)
-
-    return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
+    yield img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html

 detection_model_eval_playground = gr.Interface(
-    fn=ensemble_prediction,
+    fn=ensemble_prediction_stream,
     inputs=[
         gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
         gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
@@ -462,7 +421,8 @@ detection_model_eval_playground = gr.Interface(
     ],
     title="Open Source Detection Models Found on the Hub",
     description="Space will be upgraded shortly; inference on all 6 models should take about 1.2~ seconds once we're back on CUDA. The Community Forensics mother of all detection models is now available for inference, head to the middle tab above this. Lots of exciting things coming up, stay tuned!",
-    api_name="predict"
+    api_name="predict",
+    live=True # Enable streaming
 )

 community_forensics_preview = gr.Interface(
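
For reference, here is a minimal, self-contained sketch of the streaming pattern the new ensemble_prediction_stream relies on: when the function passed to gr.Interface is a generator (it uses yield), Gradio pushes each yielded tuple to the output components as it arrives, so the per-model table can fill in before the consensus is ready. The model list, dummy scores, and components below are illustrative stand-ins, not the app's actual code.

import time
import gradio as gr

# Hypothetical stand-in for MODEL_REGISTRY; illustration only.
FAKE_MODELS = ["model-a", "model-b", "model-c"]

def stream_predictions(text: str):
    """Generator: each yield pushes a partial update to the two outputs."""
    rows = []
    for model_id in FAKE_MODELS:
        time.sleep(0.5)  # stand-in for per-model inference latency
        score = (len(text) + len(model_id)) % 100 / 100  # dummy score
        rows.append([model_id, round(score, 2)])
        # Partial update: the table grows as each model finishes; consensus not ready yet.
        yield rows, "running..."
    # Final update once all models have run.
    consensus = "HIGH" if sum(r[1] for r in rows) / len(rows) > 0.5 else "LOW"
    yield rows, consensus

demo = gr.Interface(
    fn=stream_predictions,
    inputs=gr.Textbox(label="Input"),
    outputs=[
        gr.Dataframe(headers=["Model", "Score"], label="Per-model results"),
        gr.Textbox(label="Consensus"),
    ],
)

if __name__ == "__main__":
    demo.launch()

Two notes on the diff's choices, hedged: the commit yields None for the outputs that are not ready during the per-model loop, and depending on the Gradio version None may clear a component rather than leave it untouched, which is why this sketch yields a placeholder string instead; and live=True, as far as I know, makes the interface re-run automatically whenever an input changes, while the incremental updates themselves come from the generator.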
|