feat: refactor prediction functions and enhance image loading capabilities for improved processing and noise estimation
Files changed:
- app.py (+38, −17)
- forensics/__init__.py (+2, −2)
- forensics/wavelet.py (+1, −1)
- utils/load.py (+51, −0)
app.py (CHANGED)

```diff
@@ -13,9 +13,10 @@ from utils.utils import softmax, augment_image
 from forensics.gradient import gradient_processing
 from forensics.minmax import minmax_process
 from forensics.ela import ELA
-from forensics.wavelet import
+from forensics.wavelet import noise_estimation
 from forensics.bitplane import bit_plane_extractor
 from utils.hf_logger import log_inference_data
+from utils.load import load_image
 from agents.ensemble_team import EnsembleMonitorAgent, WeightOptimizationAgent, SystemHealthAgent
 from agents.smart_agents import ContextualIntelligenceAgent, ForensicAnomalyDetectionAgent
 from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
@@ -191,9 +192,10 @@ def simple_prediction(img):
     img_byte_arr = io.BytesIO()
     img.save(img_byte_arr, format='PNG')  # Using PNG for lossless conversion, can be JPEG if preferred
     img_byte_arr.seek(0)  # Rewind to the beginning of the stream
+    im = load_image(img)

     result = client.predict(
-        input_image=handle_file(
+        input_image=handle_file(im),
         api_name="/simple_predict"
     )
     return result
@@ -247,15 +249,34 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
             "Label": f"Error: {str(e)}"
         }

-
-
-
+def full_prediction(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
+    """Full prediction run, with a team of ensembles and agents.
+
+    Args:
+        img (url: str, Image.Image, np.ndarray): The input image to classify.
+        confidence_threshold (float, optional): The confidence threshold for classification. Defaults to 0.75.
+        augment_methods (list, optional): The augmentation methods to use.
+        rotate_degrees (int, optional): The degrees to rotate the image.
+        noise_level (int, optional): The noise level to use.
+        sharpen_strength (int, optional): The sharpen strength to use.
+
+    Returns:
+        dict: A dictionary containing the model details, classification scores, and label.
+    """
+    # Ensure img is a PIL Image object
+    if img is None:
+        raise gr.Error("No image provided. Please upload an image to analyze.")
+
     if not isinstance(img, Image.Image):
         try:
             img = Image.fromarray(img)
         except Exception as e:
             logger.error(f"Error converting input image to PIL: {e}")
-            raise
+            raise gr.Error("Input image could not be converted to a valid image format. Please try another image.")
+
+    # Ensure image is in RGB format for consistent processing
+    if img.mode != 'RGB':
+        img = img.convert('RGB')

     monitor_agent = EnsembleMonitorAgent()
     weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
@@ -406,7 +427,7 @@ def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotat
     yield img_pil, cleaned_forensics_images, table_rows, json_results, consensus_html

 detection_model_eval_playground = gr.Interface(
-    fn=
+    fn=full_prediction,
     inputs=[
         gr.Image(label="Upload Image to Analyze", sources=['upload', 'webcam'], type='pil'),
         gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Confidence Threshold"),
@@ -426,8 +447,8 @@ detection_model_eval_playground = gr.Interface(
         gr.JSON(label="Raw Model Results", visible=False),
         gr.Markdown(label="Consensus", value="")
     ],
-    title="
-    description="
+    title="Multi-Model Ensemble + Agentic Coordinated Deepfake Detection",
+    description="The detection of AI-generated images has entered a critical inflection point. While existing solutions struggle with outdated datasets and inflated claims, our approach prioritizes agility, community collaboration, and an offensive approach to deepfake detection.",
     api_name="predict",
     live=True  # Enable streaming
 )
@@ -436,9 +457,9 @@ community_forensics_preview = gr.Interface(
     fn=lambda: gr.load("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview", src="spaces"),
     inputs=None,
     outputs=gr.HTML(),  # or gr.Markdown() if it's just text
-    title="
-    description="
-    api_name="
+    title="Quick and simple prediction by our strongest model.",
+    description="No ensemble, no context, no agents, just a quick and simple prediction by our strongest model.",
+    api_name="quick_predict"
 )

 leaderboard = gr.Interface(
@@ -453,13 +474,13 @@ simple_predict_interface = gr.Interface(
     fn=simple_prediction,
     inputs=gr.Image(type="filepath"),
     outputs=gr.Text(),
-    title="
-    description="",
+    title="Quick and simple prediction by our strongest model.",
+    description="No ensemble, no context, no agents, just a quick and simple prediction by our strongest model.",
     api_name="simple_predict"
 )

-
-    fn=
+noise_estimation_interface = gr.Interface(
+    fn=noise_estimation,
     inputs=[gr.Image(type="pil"), gr.Slider(1, 32, value=8, step=1, label="Block Size")],
     outputs=gr.Image(type="pil"),
     title="Wavelet-Based Noise Analysis",
@@ -529,7 +550,7 @@ demo = gr.TabbedInterface(
     [
         detection_model_eval_playground,
         simple_predict_interface,
-
+        noise_estimation_interface,
         bit_plane_interface,
         ela_interface,
         gradient_processing_interface,
```
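For reference, a minimal sketch of how an external caller could exercise the `simple_predict` endpoint that `simple_prediction` now backs, using the same `gradio_client` + `handle_file` pattern the function itself relies on. The Space id and file name below are placeholders, not part of this commit.

```python
# Illustrative only: the Space id and image path are hypothetical.
from gradio_client import Client, handle_file

client = Client("aiwithoutborders-xyz/your-space-id")  # hypothetical target Space
result = client.predict(
    handle_file("sample.jpg"),   # local path or URL wrapped for upload
    api_name="/simple_predict",
)
print(result)
```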
forensics/__init__.py (CHANGED)

```diff
@@ -3,7 +3,7 @@ from .ela import ELA
 # from .exif import exif_full_dump
 from .gradient import gradient_processing
 from .minmax import minmax_process
-from .wavelet import
+from .wavelet import noise_estimation

 __all__ = [
     'bit_plane_extractor',
@@ -11,5 +11,5 @@ __all__ = [
     # 'exif_full_dump',
     'gradient_processing',
     'minmax_process',
-    '
+    'noise_estimation'
 ]
```
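With `noise_estimation` re-exported here, the function can be imported from the package root (illustrative):

```python
# Resolves to forensics.wavelet.noise_estimation after this change.
from forensics import noise_estimation
```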
forensics/wavelet.py (CHANGED)

```diff
@@ -3,7 +3,7 @@ import pywt
 import cv2
 from PIL import Image

-def
+def noise_estimation(image: Image.Image, blocksize: int = 8) -> Image.Image:
     """Estimate local noise using wavelet blocking. Returns a PIL image of the noise map."""
     im = np.array(image.convert('L'))
     y = np.double(im)
```
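Given the new signature, a typical call looks like the sketch below; the file names are placeholders.

```python
# Usage sketch for the renamed, typed entry point; paths are placeholders.
from PIL import Image
from forensics.wavelet import noise_estimation

img = Image.open("photo.jpg")
noise_map = noise_estimation(img, blocksize=8)  # PIL image of the local-noise map
noise_map.save("noise_map.png")
```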
utils/load.py (ADDED)

```diff
@@ -0,0 +1,51 @@
+import os
+import tempfile
+from typing import Any, Callable, List, Optional, Tuple, Union
+from urllib.parse import unquote, urlparse
+
+import PIL.Image
+import PIL.ImageOps
+import requests
+
+def load_image(
+    image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
+) -> PIL.Image.Image:
+    """
+    Loads `image` to a PIL Image.
+
+    Args:
+        image (`str` or `PIL.Image.Image`):
+            The image to convert to the PIL Image format.
+        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
+            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
+            "RGB".
+
+    Returns:
+        `PIL.Image.Image`:
+            A PIL Image.
+    """
+    if isinstance(image, str):
+        if image.startswith("http://") or image.startswith("https://"):
+            image = PIL.Image.open(requests.get(image, stream=True, timeout=600).raw)
+        elif os.path.isfile(image):
+            image = PIL.Image.open(image)
+        else:
+            raise ValueError(
+                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
+            )
+    elif isinstance(image, PIL.Image.Image):
+        image = image
+    else:
+        raise ValueError(
+            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
+        )
+
+    image = PIL.ImageOps.exif_transpose(image)
+
+    if convert_method is not None:
+        image = convert_method(image)
+    else:
+        image = image.convert("RGB")
+
+    return image
+
```
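`load_image` accepts an `http(s)` URL, a local file path, or an existing PIL image, applies EXIF orientation, and converts to RGB unless a custom `convert_method` is supplied. A short usage sketch (URLs and paths are placeholders):

```python
# Illustrative calls against the new helper; the inputs are placeholders.
from utils.load import load_image

img_from_url = load_image("https://example.com/sample.png")    # fetched with requests
img_from_disk = load_image("samples/sample.png")                # opened from a local path
img_grayscale = load_image("samples/sample.png", convert_method=lambda im: im.convert("L"))
```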