# heatmap / combined_visualization.py
# Uploaded by noumanjavaid ("Upload 11 files", commit 9e629a3, verified) - 9.27 kB
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np

from heatmap_generator import HeatmapGenerator
from image_processor import ImageProcessor
class CombinedVisualizer:
    """Overlay a difference image, threat heatmaps and labelled boxes onto one image."""

    # Box/label colours per threat level as RGB tuples (0-255). The composite is
    # treated as RGB throughout (the colormapped difference image is converted
    # BGR->RGB and the result is displayed with matplotlib), so drawing colours
    # must be RGB as well.
    _THREAT_COLORS = {
        'low': (0, 255, 0),       # Green
        'medium': (255, 165, 0),  # Orange
        'high': (255, 0, 0),      # Red
    }
    # Colour used for any threat level not listed above.
    _DEFAULT_THREAT_COLOR = (255, 0, 0)  # Red

    def __init__(self):
        """
        Initialize the combined visualizer for creating overlaid threat visualizations
        """
        self.image_processor = ImageProcessor()
        self.heatmap_generator = HeatmapGenerator()

    def create_combined_visualization(self, image_pair_results, output_path, alpha_diff=0.4,
                                      alpha_low=0.3, alpha_medium=0.4, dpi=300):
        """
        Create a combined visualization that overlays the difference image, the low and
        medium threat heatmaps, and labelled bounding boxes on top of the original image.

        Args:
            image_pair_results: Dictionary of processing results. Required keys:
                'original_image' and 'difference_image'. Optional keys:
                'multi_heatmaps' (dict that may hold 'low' / 'medium' heatmap images),
                'labeled_regions' (list of dicts with 'bbox' (x, y, w, h),
                'threat_level' and optionally 'difference_percentage'),
                'bounding_boxes' (list of (x, y, w, h)) and 'threat_summary'.
                When no usable labeled regions are given but 'threat_summary' is
                present, every bounding box is drawn as a 'medium' threat at 50%.
            output_path: Path to save the rendered (titled) visualization. A raw copy of
                the composite is saved next to it as '<name>_raw<ext>'.
            alpha_diff: Transparency for difference image overlay
            alpha_low: Transparency for low threat heatmap overlay
            alpha_medium: Transparency for medium threat heatmap overlay
            dpi: Resolution for saved image

        Returns:
            output_path, the path of the rendered visualization.
        """
        original_image = image_pair_results['original_image']
        difference_image = image_pair_results['difference_image']
        multi_heatmaps = image_pair_results.get('multi_heatmaps') or {}
        labeled_regions = self._resolve_labeled_regions(image_pair_results)

        # Work on a copy so the caller's original image is never modified.
        combined_image = original_image.copy()

        # 1. Difference image, colour-mapped when it is single-channel.
        diff_colored = self._colorize_difference(difference_image)
        combined_image = self.image_processor.overlay_images(combined_image, diff_colored, alpha_diff)

        # 2./3. Heatmaps in a fixed order so the medium layer ends up on top of the low one.
        for level, alpha in (('low', alpha_low), ('medium', alpha_medium)):
            heatmap = multi_heatmaps.get(level)
            if heatmap is not None:
                combined_image = self.image_processor.overlay_images(combined_image, heatmap, alpha)

        # 4. Bounding boxes and threat labels (drawn in place).
        self._draw_regions(combined_image, labeled_regions)

        # Render with a title; always release the figure, even if saving fails.
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.imshow(combined_image)
            ax.set_title('Combined Threat Visualization')
            ax.axis('off')
            fig.tight_layout()
            fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
        finally:
            plt.close(fig)

        # Also save the raw composite for potential further processing.
        self.image_processor.save_image(combined_image, self._raw_output_path(output_path))
        return output_path

    def create_combined_visualization_from_files(self, original_path, difference_path,
                                                 low_heatmap_path, medium_heatmap_path,
                                                 bounding_boxes, output_path,
                                                 alpha_diff=0.4, alpha_low=0.3, alpha_medium=0.4,
                                                 dpi=300):
        """
        Create a combined visualization from individual image files.

        Args:
            original_path: Path to original image
            difference_path: Path to difference image
            low_heatmap_path: Path to low threat heatmap (falsy to skip)
            medium_heatmap_path: Path to medium threat heatmap (falsy to skip)
            bounding_boxes: List of bounding boxes (x, y, w, h). They carry no threat
                level, so each is drawn as a 'medium' threat.
            output_path: Path to save the visualization
            alpha_diff: Transparency for difference image overlay
            alpha_low: Transparency for low threat heatmap overlay
            alpha_medium: Transparency for medium threat heatmap overlay
            dpi: Resolution for saved image

        Returns:
            Path to the generated visualization

        Raises:
            ValueError: if the original or difference image cannot be loaded.
        """
        load = self.image_processor.load_image
        original_image = load(original_path)
        difference_image = load(difference_path)
        # NOTE(review): assumes load_image returns None on failure (the heatmap
        # checks below rely on the same convention) - confirm in ImageProcessor.
        if original_image is None:
            raise ValueError(f"Could not load original image: {original_path}")
        if difference_image is None:
            raise ValueError(f"Could not load difference image: {difference_path}")

        multi_heatmaps = {}
        for level, path in (('low', low_heatmap_path), ('medium', medium_heatmap_path)):
            if path:
                heatmap = load(path)
                if heatmap is not None:
                    multi_heatmaps[level] = heatmap

        # Without explicit labels the supplied boxes would otherwise be dropped by the
        # main method; draw them at the documented default 'medium' level instead.
        labeled_regions = [{'bbox': bbox, 'threat_level': 'medium'}
                           for bbox in bounding_boxes or []]

        image_pair_results = {
            'original_image': original_image,
            'difference_image': difference_image,
            'bounding_boxes': bounding_boxes,
            'multi_heatmaps': multi_heatmaps,
            'labeled_regions': labeled_regions,
        }
        return self.create_combined_visualization(
            image_pair_results, output_path, alpha_diff, alpha_low, alpha_medium, dpi=dpi
        )

    @staticmethod
    def _resolve_labeled_regions(image_pair_results):
        """Return the drawable regions: valid labeled regions, else default 'medium' boxes."""
        regions = [r for r in image_pair_results.get('labeled_regions') or []
                   if 'bbox' in r and 'threat_level' in r]
        if not regions and 'threat_summary' in image_pair_results:
            # No per-region threat levels: fall back to the raw bounding boxes at a
            # default 'medium' level and a placeholder 50% difference.
            regions = [
                {'bbox': bbox, 'threat_level': 'medium', 'difference_percentage': 50}
                for bbox in image_pair_results.get('bounding_boxes') or []
            ]
        return regions

    @staticmethod
    def _colorize_difference(difference_image):
        """Map a single-channel difference image to an RGB 'hot' colormap; pass others through."""
        if difference_image.ndim == 3 and difference_image.shape[2] != 1:
            # Already multi-channel: assumed to be RGB and used as-is.
            return difference_image
        gray = difference_image.reshape(difference_image.shape[:2])
        if gray.dtype != np.uint8:
            # cv2.applyColorMap only accepts 8-bit input; stretch other dtypes to 0-255.
            gray = cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        colored = cv2.applyColorMap(gray, cv2.COLORMAP_HOT)
        # applyColorMap produces BGR; the composite is RGB.
        return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    @classmethod
    def _draw_regions(cls, image, labeled_regions):
        """Draw each region's bounding box and threat label onto image in place."""
        for region in labeled_regions:
            threat_level = str(region['threat_level'])
            # OpenCV drawing functions require integer pixel coordinates.
            x, y, w, h = (int(v) for v in region['bbox'])
            color = cls._THREAT_COLORS.get(threat_level, cls._DEFAULT_THREAT_COLOR)
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)

            label_text = threat_level.upper()
            if 'difference_percentage' in region:
                label_text = f"{label_text}: {region['difference_percentage']:.1f}%"
            # Keep the label on-canvas when the box touches the top edge.
            text_y = max(y - 10, 12)
            cv2.putText(image, label_text, (x, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    @staticmethod
    def _raw_output_path(output_path):
        """Return '<root>_raw<ext>' for output_path (extension defaults to '.png')."""
        root, ext = os.path.splitext(output_path)
        return f"{root}_raw{ext or '.png'}"
# Example usage: run the full pipeline on one image pair and save the combined view.
if __name__ == "__main__":
    from deepfake_detector import DeepfakeDetector
    from labeling import ThreatLabeler

    # Initialize pipeline components
    detector = DeepfakeDetector()
    labeler = ThreatLabeler()
    heatmap_gen = HeatmapGenerator()
    img_processor = ImageProcessor()
    visualizer = CombinedVisualizer()

    # Example paths (placeholders - replace with real files)
    image1_path = "path/to/original.jpg"
    image2_path = "path/to/modified.jpg"
    output_dir = "path/to/output"

    # Create the output directory; exist_ok avoids a check-then-create race.
    os.makedirs(output_dir, exist_ok=True)

    # Detect differences between the two images
    results = detector.process_image_pair(image1_path, image2_path)

    # Label regions (the annotated image returned first is not needed here)
    original_image = img_processor.load_image(image1_path)
    _, labeled_regions = labeler.label_regions(
        original_image, results['difference_image'], results['bounding_boxes']
    )

    # Generate per-threat-level heatmaps
    multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)

    # Assemble inputs for the combined visualization
    image_pair_results = {
        'original_image': original_image,
        'difference_image': results['difference_image'],
        'bounding_boxes': results['bounding_boxes'],
        'multi_heatmaps': multi_heatmaps,
        'labeled_regions': labeled_regions,
    }

    # Create and report the combined visualization
    output_path = os.path.join(output_dir, "combined_visualization.png")
    saved_path = visualizer.create_combined_visualization(image_pair_results, output_path)
    print(f"Combined visualization saved to: {saved_path}")