LPX55 committed on
Commit
c56a0f7
·
1 Parent(s): be96dd0

major(feat): implement streaming ensemble prediction to enhance real-time model inference and update interface for live results

Browse files
Files changed (1) hide show
  1. app.py +22 -62
app.py CHANGED
@@ -240,23 +240,9 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
240
  "Label": f"Error: {str(e)}"
241
  }
242
 
243
- def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
244
- """Full ensemble prediction pipeline.
245
-
246
- Args:
247
- img (Image.Image): The input image to classify.
248
- confidence_threshold (float): The confidence threshold for classification.
249
- augment_methods (list): The augmentation methods to apply to the image.
250
- rotate_degrees (int): The degrees to rotate the image.
251
- noise_level (int): The noise level to add to the image.
252
- sharpen_strength (int): The strength of the sharpening to apply to the image.
253
-
254
- Raises:
255
- ValueError: If the input image could not be converted to a PIL Image.
256
-
257
- Returns:
258
- tuple: A tuple containing the processed image, forensic images, model predictions, raw model results, and consensus.
259
- """
260
  if not isinstance(img, Image.Image):
261
  try:
262
  img = Image.fromarray(img)
@@ -270,35 +256,44 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
270
  health_agent = SystemHealthAgent()
271
  context_agent = ContextualIntelligenceAgent()
272
  anomaly_agent = ForensicAnomalyDetectionAgent()
273
-
274
  health_agent.monitor_system_health()
275
 
276
  if augment_methods:
277
  img_pil, _ = augment_image(img, augment_methods, rotate_degrees, noise_level, sharpen_strength)
278
  else:
279
  img_pil = img
280
- img_np_og = np.array(img) # Convert PIL Image to NumPy array
281
 
282
  model_predictions_raw = {}
283
  confidence_scores = {}
284
  results = []
 
285
 
 
286
  for model_id in MODEL_REGISTRY:
287
  model_start = time.time()
288
  result = infer(img_pil, model_id, confidence_threshold)
289
  model_end = time.time()
290
-
291
  monitor_agent.monitor_prediction(
292
  model_id,
293
  result["Label"],
294
  max(result.get("AI Score", 0.0), result.get("Real Score", 0.0)),
295
  model_end - model_start
296
  )
297
-
298
  model_predictions_raw[model_id] = result
299
  confidence_scores[model_id] = max(result.get("AI Score", 0.0), result.get("Real Score", 0.0))
300
  results.append(result)
301
-
 
 
 
 
 
 
 
 
 
 
302
  image_data_for_context = {
303
  "width": img.width,
304
  "height": img.height,
@@ -306,43 +301,29 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
306
  }
307
  detected_context_tags = context_agent.infer_context_tags(image_data_for_context, model_predictions_raw)
308
  logger.info(f"Detected context tags: {detected_context_tags}")
309
-
310
  adjusted_weights = weight_manager.adjust_weights(model_predictions_raw, confidence_scores, context_tags=detected_context_tags)
311
-
312
- weighted_predictions = {
313
- "AI": 0.0,
314
- "REAL": 0.0,
315
- "UNCERTAIN": 0.0
316
- }
317
-
318
  for model_id, prediction in model_predictions_raw.items():
319
  prediction_label = prediction.get("Label")
320
  if prediction_label in weighted_predictions:
321
  weighted_predictions[prediction_label] += adjusted_weights[model_id]
322
  else:
323
  logger.warning(f"Unexpected prediction label '{prediction_label}' from model '{model_id}'. Skipping its weight in consensus.")
324
-
325
  final_prediction_label = "UNCERTAIN"
326
  if weighted_predictions["AI"] > weighted_predictions["REAL"] and weighted_predictions["AI"] > weighted_predictions["UNCERTAIN"]:
327
  final_prediction_label = "AI"
328
  elif weighted_predictions["REAL"] > weighted_predictions["AI"] and weighted_predictions["REAL"] > weighted_predictions["UNCERTAIN"]:
329
  final_prediction_label = "REAL"
330
-
331
  optimization_agent.analyze_performance(final_prediction_label, None)
332
-
333
  gradient_image = gradient_processing(img_np_og)
334
  gradient_image2 = gradient_processing(img_np_og, intensity=45, equalize=True)
335
-
336
  minmax_image = minmax_process(img_np_og)
337
  minmax_image2 = minmax_process(img_np_og, radius=6)
338
  bitplane_image = bit_plane_extractor(img_pil)
339
-
340
  ela1 = ELA(img_np_og, quality=75, scale=50, contrast=20, linear=False, grayscale=True)
341
  ela2 = ELA(img_np_og, quality=75, scale=75, contrast=25, linear=False, grayscale=True)
342
  ela3 = ELA(img_np_og, quality=75, scale=75, contrast=25, linear=False, grayscale=False)
343
-
344
  forensics_images = [img_pil, ela1, ela2, ela3, gradient_image, gradient_image2, minmax_image, minmax_image2, bitplane_image]
345
-
346
  forensic_output_descriptions = [
347
  f"Original augmented image (PIL): {img_pil.width}x{img_pil.height}",
348
  "ELA analysis (Pass 1): Grayscale error map, quality 75.",
@@ -356,21 +337,7 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
356
  ]
357
  anomaly_detection_results = anomaly_agent.analyze_forensic_outputs(forensic_output_descriptions)
358
  logger.info(f"Forensic anomaly detection: {anomaly_detection_results['summary']}")
359
-
360
- table_rows = [[
361
- r.get("Model", ""),
362
- r.get("Contributor", ""),
363
- round(r.get("AI Score", 0.0), 3) if r.get("AI Score") is not None else 0.0,
364
- round(r.get("Real Score", 0.0), 3) if r.get("Real Score") is not None else 0.0,
365
- r.get("Label", "Error")
366
- ] for r in results]
367
-
368
- logger.info(f"Type of table_rows: {type(table_rows)}")
369
- for i, row in enumerate(table_rows):
370
- logger.info(f"Row {i} types: {[type(item) for item in row]}")
371
-
372
  consensus_html = f"<b><span style='color:{'red' if final_prediction_label == 'AI' else ('green' if final_prediction_label == 'REAL' else 'orange')}'>{final_prediction_label}</span></b>"
373
-
374
  inference_params = {
375
  "confidence_threshold": confidence_threshold,
376
  "augment_methods": augment_methods,
@@ -379,13 +346,11 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
379
  "sharpen_strength": sharpen_strength,
380
  "detected_context_tags": detected_context_tags
381
  }
382
-
383
  ensemble_output_data = {
384
  "final_prediction_label": final_prediction_label,
385
  "weighted_predictions": weighted_predictions,
386
  "adjusted_weights": adjusted_weights
387
  }
388
-
389
  agent_monitoring_data_log = {
390
  "ensemble_monitor": {
391
  "alerts": monitor_agent.alerts,
@@ -403,7 +368,6 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
403
  },
404
  "forensic_anomaly_detection": anomaly_detection_results
405
  }
406
-
407
  log_inference_data(
408
  original_image=img,
409
  inference_params=inference_params,
@@ -413,7 +377,6 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
413
  agent_monitoring_data=agent_monitoring_data_log,
414
  human_feedback=None
415
  )
416
-
417
  cleaned_forensics_images = []
418
  for f_img in forensics_images:
419
  if isinstance(f_img, Image.Image):
@@ -425,22 +388,18 @@ def ensemble_prediction(img, confidence_threshold, augment_methods, rotate_degre
425
  logger.warning(f"Could not convert numpy array to PIL Image for gallery: {e}")
426
  else:
427
  logger.warning(f"Unexpected type in forensic_images: {type(f_img)}. Skipping.")
428
-
429
  logger.info(f"Cleaned forensic images types: {[type(img) for img in cleaned_forensics_images]}")
430
-
431
  for i, res_dict in enumerate(results):
432
  for key in ["AI Score", "Real Score"]:
433
  value = res_dict.get(key)
434
  if isinstance(value, np.float32):
435
  res_dict[key] = float(value)
436
  logger.info(f"Converted {key} for result {i} from numpy.float32 to float.")
437
-
438
  json_results = json.dumps(results, cls=NumpyEncoder)
439
-
440
- return img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
441
 
442
  detection_model_eval_playground = gr.Interface(
443
- fn=ensemble_prediction,
444
  inputs=[
445
  gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
446
  gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
@@ -462,7 +421,8 @@ detection_model_eval_playground = gr.Interface(
462
  ],
463
  title="Open Source Detection Models Found on the Hub",
464
  description="Space will be upgraded shortly; inference on all 6 models should take about 1.2~ seconds once we're back on CUDA. The Community Forensics mother of all detection models is now available for inference, head to the middle tab above this. Lots of exciting things coming up, stay tuned!",
465
- api_name="predict"
 
466
  )
467
 
468
  community_forensics_preview = gr.Interface(
 
240
  "Label": f"Error: {str(e)}"
241
  }
242
 
243
+ # --- Streaming Ensemble Prediction ---
244
+ def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
245
+ # Setup (same as before)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  if not isinstance(img, Image.Image):
247
  try:
248
  img = Image.fromarray(img)
 
256
  health_agent = SystemHealthAgent()
257
  context_agent = ContextualIntelligenceAgent()
258
  anomaly_agent = ForensicAnomalyDetectionAgent()
 
259
  health_agent.monitor_system_health()
260
 
261
  if augment_methods:
262
  img_pil, _ = augment_image(img, augment_methods, rotate_degrees, noise_level, sharpen_strength)
263
  else:
264
  img_pil = img
265
+ img_np_og = np.array(img)
266
 
267
  model_predictions_raw = {}
268
  confidence_scores = {}
269
  results = []
270
+ table_rows = []
271
 
272
+ # Stream results as each model finishes
273
  for model_id in MODEL_REGISTRY:
274
  model_start = time.time()
275
  result = infer(img_pil, model_id, confidence_threshold)
276
  model_end = time.time()
 
277
  monitor_agent.monitor_prediction(
278
  model_id,
279
  result["Label"],
280
  max(result.get("AI Score", 0.0), result.get("Real Score", 0.0)),
281
  model_end - model_start
282
  )
 
283
  model_predictions_raw[model_id] = result
284
  confidence_scores[model_id] = max(result.get("AI Score", 0.0), result.get("Real Score", 0.0))
285
  results.append(result)
286
+ table_rows.append([
287
+ result.get("Model", ""),
288
+ result.get("Contributor", ""),
289
+ round(result.get("AI Score", 0.0), 3) if result.get("AI Score") is not None else 0.0,
290
+ round(result.get("Real Score", 0.0), 3) if result.get("Real Score") is not None else 0.0,
291
+ result.get("Label", "Error")
292
+ ])
293
+ # Yield partial results: only update the table, others are None
294
+ yield None, None, table_rows, None, None
295
+
296
+ # After all models, compute the rest as before
297
  image_data_for_context = {
298
  "width": img.width,
299
  "height": img.height,
 
301
  }
302
  detected_context_tags = context_agent.infer_context_tags(image_data_for_context, model_predictions_raw)
303
  logger.info(f"Detected context tags: {detected_context_tags}")
 
304
  adjusted_weights = weight_manager.adjust_weights(model_predictions_raw, confidence_scores, context_tags=detected_context_tags)
305
+ weighted_predictions = {"AI": 0.0, "REAL": 0.0, "UNCERTAIN": 0.0}
 
 
 
 
 
 
306
  for model_id, prediction in model_predictions_raw.items():
307
  prediction_label = prediction.get("Label")
308
  if prediction_label in weighted_predictions:
309
  weighted_predictions[prediction_label] += adjusted_weights[model_id]
310
  else:
311
  logger.warning(f"Unexpected prediction label '{prediction_label}' from model '{model_id}'. Skipping its weight in consensus.")
 
312
  final_prediction_label = "UNCERTAIN"
313
  if weighted_predictions["AI"] > weighted_predictions["REAL"] and weighted_predictions["AI"] > weighted_predictions["UNCERTAIN"]:
314
  final_prediction_label = "AI"
315
  elif weighted_predictions["REAL"] > weighted_predictions["AI"] and weighted_predictions["REAL"] > weighted_predictions["UNCERTAIN"]:
316
  final_prediction_label = "REAL"
 
317
  optimization_agent.analyze_performance(final_prediction_label, None)
 
318
  gradient_image = gradient_processing(img_np_og)
319
  gradient_image2 = gradient_processing(img_np_og, intensity=45, equalize=True)
 
320
  minmax_image = minmax_process(img_np_og)
321
  minmax_image2 = minmax_process(img_np_og, radius=6)
322
  bitplane_image = bit_plane_extractor(img_pil)
 
323
  ela1 = ELA(img_np_og, quality=75, scale=50, contrast=20, linear=False, grayscale=True)
324
  ela2 = ELA(img_np_og, quality=75, scale=75, contrast=25, linear=False, grayscale=True)
325
  ela3 = ELA(img_np_og, quality=75, scale=75, contrast=25, linear=False, grayscale=False)
 
326
  forensics_images = [img_pil, ela1, ela2, ela3, gradient_image, gradient_image2, minmax_image, minmax_image2, bitplane_image]
 
327
  forensic_output_descriptions = [
328
  f"Original augmented image (PIL): {img_pil.width}x{img_pil.height}",
329
  "ELA analysis (Pass 1): Grayscale error map, quality 75.",
 
337
  ]
338
  anomaly_detection_results = anomaly_agent.analyze_forensic_outputs(forensic_output_descriptions)
339
  logger.info(f"Forensic anomaly detection: {anomaly_detection_results['summary']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  consensus_html = f"<b><span style='color:{'red' if final_prediction_label == 'AI' else ('green' if final_prediction_label == 'REAL' else 'orange')}'>{final_prediction_label}</span></b>"
 
341
  inference_params = {
342
  "confidence_threshold": confidence_threshold,
343
  "augment_methods": augment_methods,
 
346
  "sharpen_strength": sharpen_strength,
347
  "detected_context_tags": detected_context_tags
348
  }
 
349
  ensemble_output_data = {
350
  "final_prediction_label": final_prediction_label,
351
  "weighted_predictions": weighted_predictions,
352
  "adjusted_weights": adjusted_weights
353
  }
 
354
  agent_monitoring_data_log = {
355
  "ensemble_monitor": {
356
  "alerts": monitor_agent.alerts,
 
368
  },
369
  "forensic_anomaly_detection": anomaly_detection_results
370
  }
 
371
  log_inference_data(
372
  original_image=img,
373
  inference_params=inference_params,
 
377
  agent_monitoring_data=agent_monitoring_data_log,
378
  human_feedback=None
379
  )
 
380
  cleaned_forensics_images = []
381
  for f_img in forensics_images:
382
  if isinstance(f_img, Image.Image):
 
388
  logger.warning(f"Could not convert numpy array to PIL Image for gallery: {e}")
389
  else:
390
  logger.warning(f"Unexpected type in forensic_images: {type(f_img)}. Skipping.")
 
391
  logger.info(f"Cleaned forensic images types: {[type(img) for img in cleaned_forensics_images]}")
 
392
  for i, res_dict in enumerate(results):
393
  for key in ["AI Score", "Real Score"]:
394
  value = res_dict.get(key)
395
  if isinstance(value, np.float32):
396
  res_dict[key] = float(value)
397
  logger.info(f"Converted {key} for result {i} from numpy.float32 to float.")
 
398
  json_results = json.dumps(results, cls=NumpyEncoder)
399
+ yield img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
 
400
 
401
  detection_model_eval_playground = gr.Interface(
402
+ fn=ensemble_prediction_stream,
403
  inputs=[
404
  gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
405
  gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
 
421
  ],
422
  title="Open Source Detection Models Found on the Hub",
423
  description="Space will be upgraded shortly; inference on all 6 models should take about 1.2~ seconds once we're back on CUDA. The Community Forensics mother of all detection models is now available for inference, head to the middle tab above this. Lots of exciting things coming up, stay tuned!",
424
+ api_name="predict",
425
+ live=True # Re-run automatically whenever an input changes (generator-based streaming does not depend on this flag)
426
  )
427
 
428
  community_forensics_preview = gr.Interface(