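# NOTE(review): the captured file began with the lines "Spaces:", "Sleeping",
# "Sleeping" -- apparently page residue from a hosted-Space listing, not code.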
import os

import cv2
import numpy as np
import matplotlib.pyplot as plt

from image_processor import ImageProcessor
from heatmap_generator import HeatmapGenerator
class CombinedVisualizer:
    """Overlay a difference image, threat heatmaps and labeled boxes on an original image."""

    # Box/label colors keyed by threat level, as RGB triples. The composite is
    # handled as RGB throughout: the colormapped difference image is converted
    # BGR->RGB before blending, and the result is displayed with plt.imshow.
    _THREAT_COLORS = {
        'low': (0, 255, 0),       # Green
        'medium': (255, 165, 0),  # Orange
        'high': (255, 0, 0),      # Red
    }
    # Color for threat levels missing from _THREAT_COLORS (red).
    _DEFAULT_COLOR = (255, 0, 0)

    def __init__(self):
        """
        Initialize the combined visualizer for creating overlaid threat visualizations
        """
        self.image_processor = ImageProcessor()
        self.heatmap_generator = HeatmapGenerator()

    def create_combined_visualization(self, image_pair_results, output_path, alpha_diff=0.4,
                                      alpha_low=0.3, alpha_medium=0.4, dpi=300):
        """
        Create a combined visualization that overlays the difference image, the low and
        medium threat heatmaps, and threat-colored bounding boxes on the original image.

        Args:
            image_pair_results: Dictionary with required keys 'original_image' and
                'difference_image', and optional keys:
                  - 'multi_heatmaps': dict that may hold 'low' / 'medium' heatmaps;
                  - 'labeled_regions': dicts with 'bbox' (x, y, w, h) and
                    'threat_level', optionally 'difference_percentage'; entries
                    missing 'bbox' or 'threat_level' are ignored;
                  - 'bounding_boxes': (x, y, w, h) boxes, drawn as 'medium'
                    placeholder regions when no usable labeled regions exist.
            output_path: Path for the rendered figure. A raw copy of the composite
                is also saved next to it, with '_raw' inserted before the extension.
            alpha_diff: Blend weight passed to overlay_images for the difference image
            alpha_low: Blend weight passed to overlay_images for the low threat heatmap
            alpha_medium: Blend weight passed to overlay_images for the medium threat heatmap
            dpi: Resolution of the saved figure

        Returns:
            output_path, unchanged.
        """
        original_image = image_pair_results['original_image']
        difference_image = image_pair_results['difference_image']
        multi_heatmaps = image_pair_results.get('multi_heatmaps') or {}
        labeled_regions = self._collect_labeled_regions(image_pair_results)

        # Start from a copy so the caller's original image is not modified.
        combined_image = original_image.copy()

        # 1. Blend in the (colorized) difference image.
        diff_colored = self._colorize_difference(difference_image)
        combined_image = self.image_processor.overlay_images(combined_image, diff_colored, alpha_diff)

        # 2./3. Blend in heatmaps, low first so the medium heatmap ends up on top.
        for level, alpha in (('low', alpha_low), ('medium', alpha_medium)):
            heatmap = multi_heatmaps.get(level)
            if heatmap is not None:
                combined_image = self.image_processor.overlay_images(combined_image, heatmap, alpha)

        # 4. Draw threat-colored boxes and labels in place.
        self._draw_regions(combined_image, labeled_regions)

        self._save_figure(combined_image, output_path, dpi)

        # Also save the raw composite for potential further processing.
        self.image_processor.save_image(combined_image, self._raw_output_path(output_path))
        return output_path

    @staticmethod
    def _collect_labeled_regions(image_pair_results):
        """Return the regions to draw; each has at least 'bbox' and 'threat_level'.

        Labeled regions missing either key are dropped. If none remain, every entry
        of 'bounding_boxes' becomes a 'medium' region with a nominal 50% difference.
        """
        regions = [r for r in image_pair_results.get('labeled_regions') or []
                   if 'bbox' in r and 'threat_level' in r]
        if regions:
            return regions
        # Not gated on a 'threat_summary' key: callers such as
        # create_combined_visualization_from_files supply only bounding boxes,
        # and those boxes must still be drawn.
        return [
            {'bbox': bbox, 'threat_level': 'medium', 'difference_percentage': 50}
            for bbox in image_pair_results.get('bounding_boxes') or []
        ]

    @staticmethod
    def _colorize_difference(difference_image):
        """Return an image of the difference suitable for blending.

        Single-channel input (2-D, or 3-D with one channel) is mapped through the
        HOT colormap and converted to RGB; anything else is returned unchanged.
        """
        if not (difference_image.ndim == 2 or difference_image.shape[2] == 1):
            return difference_image
        gray = difference_image
        if gray.dtype != np.uint8:
            # applyColorMap only accepts 8-bit input; stretch other dtypes to 0-255.
            gray = cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        colored = cv2.applyColorMap(gray, cv2.COLORMAP_HOT)
        return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _format_label(region):
        """Return the label text for a region, e.g. 'HIGH: 12.3%' or 'LOW'."""
        level = str(region['threat_level']).upper()
        if 'difference_percentage' in region:
            return f"{level}: {region['difference_percentage']:.1f}%"
        return level

    @classmethod
    def _draw_regions(cls, image, regions):
        """Draw each region's rectangle and label onto ``image`` in place."""
        for region in regions:
            # OpenCV drawing functions require integer pixel coordinates.
            x, y, w, h = (int(v) for v in region['bbox'])
            color = cls._THREAT_COLORS.get(region['threat_level'], cls._DEFAULT_COLOR)
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
            # Keep the label baseline on-canvas for boxes touching the top edge.
            text_y = max(y - 10, 15)
            cv2.putText(image, cls._format_label(region), (x, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    @staticmethod
    def _save_figure(image, output_path, dpi):
        """Render ``image`` with a title and no axes, then save it to ``output_path``."""
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.imshow(image)
            ax.set_title('Combined Threat Visualization')
            ax.axis('off')
            fig.tight_layout()
            fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
        finally:
            # Close even on failure so repeated calls do not accumulate figures.
            plt.close(fig)

    @staticmethod
    def _raw_output_path(output_path):
        """Return ``output_path`` with '_raw' inserted before its final extension.

        Splitting on the extension (instead of textually replacing '.png') never
        yields ``output_path`` itself, so the raw image cannot overwrite the
        rendered figure for .jpg or extension-less names.
        """
        root, ext = os.path.splitext(output_path)
        return f"{root}_raw{ext}"

    def create_combined_visualization_from_files(self, original_path, difference_path,
                                                 low_heatmap_path, medium_heatmap_path,
                                                 bounding_boxes, output_path,
                                                 alpha_diff=0.4, alpha_low=0.3, alpha_medium=0.4,
                                                 dpi=300):
        """
        Create a combined visualization from individual image files.

        Args:
            original_path: Path to original image
            difference_path: Path to difference image
            low_heatmap_path: Path to low threat heatmap (falsy to skip)
            medium_heatmap_path: Path to medium threat heatmap (falsy to skip)
            bounding_boxes: List of bounding boxes (x, y, w, h); drawn as 'medium' regions
            output_path: Path to save the visualization
            alpha_diff: Blend weight for the difference image overlay
            alpha_low: Blend weight for the low threat heatmap overlay
            alpha_medium: Blend weight for the medium threat heatmap overlay
            dpi: Resolution of the saved figure

        Returns:
            Path to the generated visualization
        """
        load = self.image_processor.load_image
        original_image = load(original_path)
        difference_image = load(difference_path)

        # Only include heatmaps whose path was given and which loaded successfully.
        multi_heatmaps = {}
        for level, path in (('low', low_heatmap_path), ('medium', medium_heatmap_path)):
            if path:
                heatmap = load(path)
                if heatmap is not None:
                    multi_heatmaps[level] = heatmap

        image_pair_results = {
            'original_image': original_image,
            'difference_image': difference_image,
            'bounding_boxes': bounding_boxes,
            'multi_heatmaps': multi_heatmaps,
        }
        return self.create_combined_visualization(
            image_pair_results, output_path,
            alpha_diff=alpha_diff, alpha_low=alpha_low, alpha_medium=alpha_medium, dpi=dpi,
        )
# Example usage
if __name__ == "__main__":
    import os

    from deepfake_detector import DeepfakeDetector
    from labeling import ThreatLabeler

    # Initialize components
    detector = DeepfakeDetector()
    labeler = ThreatLabeler()
    heatmap_gen = HeatmapGenerator()
    img_processor = ImageProcessor()
    visualizer = CombinedVisualizer()

    # Example paths (placeholders -- replace with real files before running)
    image1_path = "path/to/original.jpg"
    image2_path = "path/to/modified.jpg"
    output_dir = "path/to/output"

    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(output_dir, exist_ok=True)

    # Process images
    results = detector.process_image_pair(image1_path, image2_path)

    # Label regions; label_regions returns (labeled_image, labeled_regions) and
    # only the regions are needed here.
    original_image = img_processor.load_image(image1_path)
    _, labeled_regions = labeler.label_regions(
        original_image, results['difference_image'], results['bounding_boxes']
    )

    # Generate multi-level heatmaps
    multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)

    # Prepare results for combined visualization
    image_pair_results = {
        'original_image': original_image,
        'difference_image': results['difference_image'],
        'bounding_boxes': results['bounding_boxes'],
        'multi_heatmaps': multi_heatmaps,
        'labeled_regions': labeled_regions,
    }

    # Create combined visualization
    output_path = os.path.join(output_dir, "combined_visualization.png")
    visualizer.create_combined_visualization(image_pair_results, output_path)
    print(f"Combined visualization saved to: {output_path}")