LPX55 committed on
Commit
5342779
·
1 Parent(s): 1146644

feat: refactor prediction functions and enhance image loading capabilities for improved processing and noise estimation

Browse files
Files changed (4) hide show
  1. app.py +38 -17
  2. forensics/__init__.py +2 -2
  3. forensics/wavelet.py +1 -1
  4. utils/load.py +51 -0
app.py CHANGED
@@ -13,9 +13,10 @@ from utils.utils import softmax, augment_image
13
  from forensics.gradient import gradient_processing
14
  from forensics.minmax import minmax_process
15
  from forensics.ela import ELA
16
- from forensics.wavelet import wavelet_blocking_noise_estimation
17
  from forensics.bitplane import bit_plane_extractor
18
  from utils.hf_logger import log_inference_data
 
19
  from agents.ensemble_team import EnsembleMonitorAgent, WeightOptimizationAgent, SystemHealthAgent
20
  from agents.smart_agents import ContextualIntelligenceAgent, ForensicAnomalyDetectionAgent
21
  from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
@@ -191,9 +192,10 @@ def simple_prediction(img):
191
  img_byte_arr = io.BytesIO()
192
  img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
193
  img_byte_arr.seek(0) # Rewind to the beginning of the stream
 
194
 
195
  result = client.predict(
196
- input_image=handle_file(img_byte_arr),
197
  api_name="/simple_predict"
198
  )
199
  return result
@@ -247,15 +249,34 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
247
  "Label": f"Error: {str(e)}"
248
  }
249
 
250
- # --- Streaming Ensemble Prediction ---
251
- def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
252
- # Setup (same as before)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  if not isinstance(img, Image.Image):
254
  try:
255
  img = Image.fromarray(img)
256
  except Exception as e:
257
  logger.error(f"Error converting input image to PIL: {e}")
258
- raise ValueError("Input image could not be converted to PIL Image.")
 
 
 
 
259
 
260
  monitor_agent = EnsembleMonitorAgent()
261
  weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
@@ -406,7 +427,7 @@ def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotat
406
  yield img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
407
 
408
  detection_model_eval_playground = gr.Interface(
409
- fn=ensemble_prediction_stream,
410
  inputs=[
411
  gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
412
  gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
@@ -426,8 +447,8 @@ detection_model_eval_playground = gr.Interface(
426
  gr.JSON(label="Raw Model Results", visible=False),
427
  gr.Markdown(label="Consensus", value="")
428
  ],
429
- title="Open Source Detection Models Found on the Hub",
430
- description="Space will be upgraded shortly; inference on all 6 models should take about 1.2~ seconds once we're back on CUDA. The Community Forensics mother of all detection models is now available for inference, head to the middle tab above this. Lots of exciting things coming up, stay tuned!",
431
  api_name="predict",
432
  live=True # Enable streaming
433
  )
@@ -436,9 +457,9 @@ community_forensics_preview = gr.Interface(
436
  fn=lambda: gr.load("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview", src="spaces"),
437
  inputs=None,
438
  outputs=gr.HTML(), # or gr.Markdown() if it's just text
439
- title="Community Forensics Preview",
440
- description="Community Forensics Preview coming soon!",
441
- api_name="community_forensics"
442
  )
443
 
444
  leaderboard = gr.Interface(
@@ -453,13 +474,13 @@ simple_predict_interface = gr.Interface(
453
  fn=simple_prediction,
454
  inputs=gr.Image(type="filepath"),
455
  outputs=gr.Text(),
456
- title="Simple and Fast Prediction",
457
- description="",
458
  api_name="simple_predict"
459
  )
460
 
461
- wavelet_noise_estimation = gr.Interface(
462
- fn=wavelet_blocking_noise_estimation,
463
  inputs=[gr.Image(type="pil"), gr.Slider(1, 32, value=8, step=1, label="Block Size")],
464
  outputs=gr.Image(type="pil"),
465
  title="Wavelet-Based Noise Analysis",
@@ -529,7 +550,7 @@ demo = gr.TabbedInterface(
529
  [
530
  detection_model_eval_playground,
531
  simple_predict_interface,
532
- wavelet_noise_estimation,
533
  bit_plane_interface,
534
  ela_interface,
535
  gradient_processing_interface,
 
13
  from forensics.gradient import gradient_processing
14
  from forensics.minmax import minmax_process
15
  from forensics.ela import ELA
16
+ from forensics.wavelet import noise_estimation
17
  from forensics.bitplane import bit_plane_extractor
18
  from utils.hf_logger import log_inference_data
19
+ from utils.load import load_image
20
  from agents.ensemble_team import EnsembleMonitorAgent, WeightOptimizationAgent, SystemHealthAgent
21
  from agents.smart_agents import ContextualIntelligenceAgent, ForensicAnomalyDetectionAgent
22
  from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
 
192
  img_byte_arr = io.BytesIO()
193
  img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
194
  img_byte_arr.seek(0) # Rewind to the beginning of the stream
195
+ im = load_image(img)
196
 
197
  result = client.predict(
198
+ input_image=handle_file(im),
199
  api_name="/simple_predict"
200
  )
201
  return result
 
249
  "Label": f"Error: {str(e)}"
250
  }
251
 
252
+ def full_prediction(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
253
+ """Full prediction run, with a team of ensembles and agents.
254
+
255
+ Args:
256
+ img (url: str, Image.Image, np.ndarray): The input image to classify.
257
+ confidence_threshold (float, optional): The confidence threshold for classification. Defaults to 0.75.
258
+ augment_methods (list, optional): The augmentation methods to use.
259
+ rotate_degrees (int, optional): The degrees to rotate the image.
260
+ noise_level (int, optional): The noise level to use.
261
+ sharpen_strength (int, optional): The sharpen strength to use.
262
+
263
+ Returns:
264
+ dict: A dictionary containing the model details, classification scores, and label.
265
+ """
266
+ # Ensure img is a PIL Image object
267
+ if img is None:
268
+ raise gr.Error("No image provided. Please upload an image to analyze.")
269
+
270
  if not isinstance(img, Image.Image):
271
  try:
272
  img = Image.fromarray(img)
273
  except Exception as e:
274
  logger.error(f"Error converting input image to PIL: {e}")
275
+ raise gr.Error("Input image could not be converted to a valid image format. Please try another image.")
276
+
277
+ # Ensure image is in RGB format for consistent processing
278
+ if img.mode != 'RGB':
279
+ img = img.convert('RGB')
280
 
281
  monitor_agent = EnsembleMonitorAgent()
282
  weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
 
427
  yield img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html
428
 
429
  detection_model_eval_playground = gr.Interface(
430
+ fn=full_prediction,
431
  inputs=[
432
  gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
433
  gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
 
447
  gr.JSON(label="Raw Model Results", visible=False),
448
  gr.Markdown(label="Consensus", value="")
449
  ],
450
+ title="Multi-Model Ensemble + Agentic Coordinated Deepfake Detection",
451
+ description="The detection of AI-generated images has entered a critical inflection point. While existing solutions struggle with outdated datasets and inflated claims, our approach prioritizes agility, community collaboration, and an offensive approach to deepfake detection.",
452
  api_name="predict",
453
  live=True # Enable streaming
454
  )
 
457
  fn=lambda: gr.load("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview", src="spaces"),
458
  inputs=None,
459
  outputs=gr.HTML(), # or gr.Markdown() if it's just text
460
+ title="Quick and simple prediction by our strongest model.",
461
+ description="No ensemble, no context, no agents, just a quick and simple prediction by our strongest model.",
462
+ api_name="quick_predict"
463
  )
464
 
465
  leaderboard = gr.Interface(
 
474
  fn=simple_prediction,
475
  inputs=gr.Image(type="filepath"),
476
  outputs=gr.Text(),
477
+ title="Quick and simple prediction by our strongest model.",
478
+ description="No ensemble, no context, no agents, just a quick and simple prediction by our strongest model.",
479
  api_name="simple_predict"
480
  )
481
 
482
+ noise_estimation_interface = gr.Interface(
483
+ fn=noise_estimation,
484
  inputs=[gr.Image(type="pil"), gr.Slider(1, 32, value=8, step=1, label="Block Size")],
485
  outputs=gr.Image(type="pil"),
486
  title="Wavelet-Based Noise Analysis",
 
550
  [
551
  detection_model_eval_playground,
552
  simple_predict_interface,
553
+ noise_estimation_interface,
554
  bit_plane_interface,
555
  ela_interface,
556
  gradient_processing_interface,
forensics/__init__.py CHANGED
@@ -3,7 +3,7 @@ from .ela import ELA
3
  # from .exif import exif_full_dump
4
  from .gradient import gradient_processing
5
  from .minmax import minmax_process
6
- from .wavelet import wavelet_blocking_noise_estimation
7
 
8
  __all__ = [
9
  'bit_plane_extractor',
@@ -11,5 +11,5 @@ __all__ = [
11
  # 'exif_full_dump',
12
  'gradient_processing',
13
  'minmax_process',
14
- 'wavelet_blocking_noise_estimation'
15
  ]
 
3
  # from .exif import exif_full_dump
4
  from .gradient import gradient_processing
5
  from .minmax import minmax_process
6
+ from .wavelet import noise_estimation
7
 
8
  __all__ = [
9
  'bit_plane_extractor',
 
11
  # 'exif_full_dump',
12
  'gradient_processing',
13
  'minmax_process',
14
+ 'noise_estimation'
15
  ]
forensics/wavelet.py CHANGED
@@ -3,7 +3,7 @@ import pywt
3
  import cv2
4
  from PIL import Image
5
 
6
- def wavelet_blocking_noise_estimation(image: Image.Image, blocksize: int = 8) -> Image.Image:
7
  """Estimate local noise using wavelet blocking. Returns a PIL image of the noise map."""
8
  im = np.array(image.convert('L'))
9
  y = np.double(im)
 
3
  import cv2
4
  from PIL import Image
5
 
6
+ def noise_estimation(image: Image.Image, blocksize: int = 8) -> Image.Image:
7
  """Estimate local noise using wavelet blocking. Returns a PIL image of the noise map."""
8
  im = np.array(image.convert('L'))
9
  y = np.double(im)
utils/load.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from typing import Any, Callable, List, Optional, Tuple, Union
4
+ from urllib.parse import unquote, urlparse
5
+
6
+ import PIL.Image
7
+ import PIL.ImageOps
8
+ import requests
9
+
10
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            An `http://` / `https://` URL, a path to a local file, or an already-loaded PIL image.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be
            converted to "RGB".

    Returns:
        `PIL.Image.Image`:
            The loaded image after `PIL.ImageOps.exif_transpose` and the conversion step have been applied.

    Raises:
        ValueError:
            If `image` is a string that is neither an HTTP(S) URL nor an existing file, or if it is neither a
            string nor a PIL image.
        requests.RequestException:
            If the download fails or the server answers with an error status (an `OSError` subclass).
    """
    # Local import keeps this function self-contained; `io` is only needed for the URL branch.
    import io

    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            response = requests.get(image, timeout=600)
            # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            # `response.content` is fully downloaded and content-decoded (gzip/deflate),
            # unlike `response.raw`, and gives PIL a seekable in-memory buffer.
            image = PIL.Image.open(io.BytesIO(response.content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    # Apply the EXIF orientation tag so downstream processing sees the image upright.
    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        return convert_method(image)
    return image.convert("RGB")
51
+