"""heatmap/comparison_interface.py

Comparison interface for the deepfake detection pipeline: builds a static
grid visualization and an interactive HTML comparison of every processing
stage (detection, threat labeling, and heatmap generation).
"""
import os
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from deepfake_detector import DeepfakeDetector
from image_processor import ImageProcessor
from labeling import ThreatLabeler
from heatmap_generator import HeatmapGenerator
class ComparisonInterface:
def __init__(self):
"""
Initialize the comparison interface for visualizing all processing stages
"""
self.img_processor = ImageProcessor()
def create_comparison_grid(self, image_pair_results, output_path, figsize=(18, 12), dpi=300):
"""
Create a comprehensive grid visualization of all processing stages
Args:
image_pair_results: Dictionary containing all processing results
output_path: Path to save the visualization
figsize: Figure size (width, height) in inches
dpi: Resolution for saved image
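        Example (illustrative sketch; assumes ``results`` is the dict assembled
        by process_and_visualize below, containing the image and summary keys
        this method reads):
            interface = ComparisonInterface()
            grid_path = interface.create_comparison_grid(
                results, 'output/pair_comparison.png', figsize=(18, 12), dpi=150)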
"""
# Extract images from results
original_image = image_pair_results['original_image']
modified_image = image_pair_results['modified_image']
difference_image = image_pair_results['difference_image']
threshold_image = image_pair_results['threshold_image']
annotated_image = image_pair_results['annotated_image']
labeled_image = image_pair_results['labeled_image']
heatmap_overlay = image_pair_results['heatmap_overlay']
# Extract multi-level heatmaps if available
multi_heatmaps = image_pair_results.get('multi_heatmaps', {})
# Create figure with grid layout
fig = plt.figure(figsize=figsize)
gs = GridSpec(3, 4, figure=fig)
# Row 1: Original images and difference
ax1 = fig.add_subplot(gs[0, 0])
ax1.imshow(original_image)
ax1.set_title('Original Image')
ax1.axis('off')
ax2 = fig.add_subplot(gs[0, 1])
ax2.imshow(modified_image)
ax2.set_title('Modified Image')
ax2.axis('off')
ax3 = fig.add_subplot(gs[0, 2])
ax3.imshow(difference_image, cmap='gray')
ax3.set_title('Difference Image')
ax3.axis('off')
ax4 = fig.add_subplot(gs[0, 3])
ax4.imshow(threshold_image, cmap='gray')
ax4.set_title('Thresholded Difference')
ax4.axis('off')
# Row 2: Annotated, labeled, and heatmap
ax5 = fig.add_subplot(gs[1, 0:2])
ax5.imshow(annotated_image)
ax5.set_title('Detected Regions')
ax5.axis('off')
ax6 = fig.add_subplot(gs[1, 2:4])
ax6.imshow(labeled_image)
ax6.set_title('Threat Labeled Regions')
ax6.axis('off')
# Row 3: Multi-level heatmaps
if 'low' in multi_heatmaps and 'medium' in multi_heatmaps and 'high' in multi_heatmaps:
ax7 = fig.add_subplot(gs[2, 0])
ax7.imshow(multi_heatmaps['low'])
ax7.set_title('Low Threat Heatmap')
ax7.axis('off')
ax8 = fig.add_subplot(gs[2, 1])
ax8.imshow(multi_heatmaps['medium'])
ax8.set_title('Medium Threat Heatmap')
ax8.axis('off')
ax9 = fig.add_subplot(gs[2, 2])
ax9.imshow(multi_heatmaps['high'])
ax9.set_title('High Threat Heatmap')
ax9.axis('off')
else:
# If multi-level heatmaps not available, show combined heatmap in larger space
ax7 = fig.add_subplot(gs[2, 0:3])
ax7.imshow(heatmap_overlay)
ax7.set_title('Combined Threat Heatmap')
ax7.axis('off')
# Add threat summary in text box
ax10 = fig.add_subplot(gs[2, 3])
ax10.axis('off')
summary_text = self._format_summary_text(image_pair_results['threat_summary'], image_pair_results['smi_score'])
ax10.text(0, 0.5, summary_text, fontsize=10, va='center', ha='left', wrap=True)
ax10.set_title('Threat Summary')
# Add overall title
        plt.suptitle("Deepfake Detection Analysis", fontsize=16)
# Adjust layout and save
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.savefig(output_path, dpi=dpi, bbox_inches='tight')
plt.close()
return output_path
def _format_summary_text(self, threat_summary, smi_score):
"""
Format threat summary as text for display
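        Expected input shapes (inferred from the keys accessed below; not an
        authoritative schema):
            threat_summary = {
                'total_regions': int,
                'threat_counts': {'low': int, 'medium': int, 'high': int},
                'max_threat': {'level': str, 'percentage': float} or None,
                'average_difference': float,
            }
            smi_score: float (1.0 = identical, 0.0 = completely different)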
"""
text = f"SMI Score: {smi_score:.4f}\n"
text += f"(1.0 = identical, 0.0 = different)\n\n"
text += f"Total regions: {threat_summary['total_regions']}\n\n"
text += f"Threat counts:\n"
text += f" Low: {threat_summary['threat_counts']['low']}\n"
text += f" Medium: {threat_summary['threat_counts']['medium']}\n"
text += f" High: {threat_summary['threat_counts']['high']}\n\n"
if threat_summary['max_threat']:
text += f"Maximum threat: {threat_summary['max_threat']['level'].upper()}\n"
text += f" ({threat_summary['max_threat']['percentage']:.1f}%)\n\n"
text += f"Average difference: {threat_summary['average_difference']:.1f}%"
return text
def create_interactive_comparison(self, image_pair_results, output_path):
"""
Create an HTML file with interactive comparison of all processing stages
Args:
image_pair_results: Dictionary containing all processing results
output_path: Path to save the HTML file
Returns:
Path to the generated HTML file
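        Example (illustrative sketch; ``results`` as assembled in
        process_and_visualize, template placeholders such as
        {{original_image_path}} are substituted below):
            interface = ComparisonInterface()
            html_path = interface.create_interactive_comparison(
                results, 'output/pair_interactive.html')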
"""
# Create output directory for individual images
output_dir = os.path.dirname(output_path)
images_dir = os.path.join(output_dir, 'images')
if not os.path.exists(images_dir):
os.makedirs(images_dir)
# Get base filename for outputs
        base_name = os.path.splitext(os.path.basename(output_path))[0]
# Save individual images for HTML display
image_paths = {}
# Save original and modified images
original_path = os.path.join(images_dir, f"{base_name}_original.png")
modified_path = os.path.join(images_dir, f"{base_name}_modified.png")
self.img_processor.save_image(image_pair_results['original_image'], original_path)
self.img_processor.save_image(image_pair_results['modified_image'], modified_path)
image_paths['original_image_path'] = os.path.relpath(original_path, output_dir)
image_paths['modified_image_path'] = os.path.relpath(modified_path, output_dir)
# Save difference and threshold images
difference_path = os.path.join(images_dir, f"{base_name}_difference.png")
threshold_path = os.path.join(images_dir, f"{base_name}_threshold.png")
self.img_processor.save_image(image_pair_results['difference_image'], difference_path)
self.img_processor.save_image(image_pair_results['threshold_image'], threshold_path)
image_paths['difference_image_path'] = os.path.relpath(difference_path, output_dir)
image_paths['threshold_image_path'] = os.path.relpath(threshold_path, output_dir)
# Save annotated and labeled images
annotated_path = os.path.join(images_dir, f"{base_name}_annotated.png")
labeled_path = os.path.join(images_dir, f"{base_name}_labeled.png")
self.img_processor.save_image(image_pair_results['annotated_image'], annotated_path)
self.img_processor.save_image(image_pair_results['labeled_image'], labeled_path)
image_paths['annotated_image_path'] = os.path.relpath(annotated_path, output_dir)
image_paths['labeled_image_path'] = os.path.relpath(labeled_path, output_dir)
# Save heatmap overlay
heatmap_path = os.path.join(images_dir, f"{base_name}_heatmap.png")
self.img_processor.save_image(image_pair_results['heatmap_overlay'], heatmap_path)
image_paths['heatmap_overlay_path'] = os.path.relpath(heatmap_path, output_dir)
# Save multi-level heatmaps if available
multi_heatmaps = image_pair_results.get('multi_heatmaps', {})
if 'low' in multi_heatmaps and 'medium' in multi_heatmaps and 'high' in multi_heatmaps:
low_path = os.path.join(images_dir, f"{base_name}_heatmap_low.png")
medium_path = os.path.join(images_dir, f"{base_name}_heatmap_medium.png")
high_path = os.path.join(images_dir, f"{base_name}_heatmap_high.png")
self.img_processor.save_image(multi_heatmaps['low'], low_path)
self.img_processor.save_image(multi_heatmaps['medium'], medium_path)
self.img_processor.save_image(multi_heatmaps['high'], high_path)
image_paths['low_heatmap_path'] = os.path.relpath(low_path, output_dir)
image_paths['medium_heatmap_path'] = os.path.relpath(medium_path, output_dir)
image_paths['high_heatmap_path'] = os.path.relpath(high_path, output_dir)
# Format threat summary for HTML display
threat_summary_text = self._format_summary_text(
image_pair_results['threat_summary'],
image_pair_results['smi_score']
)
# Read HTML template
template_path = os.path.join(os.path.dirname(__file__), 'templates', 'interactive_comparison.html')
with open(template_path, 'r') as f:
html_template = f.read()
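        # Illustrative template snippet this substitution expects (the actual
        # markup lives in templates/interactive_comparison.html and may differ):
        #   <img src="{{original_image_path}}">
        #   <pre>{{threat_summary}}</pre>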
# Replace placeholders with actual values
for key, value in image_paths.items():
html_template = html_template.replace(f"{{{{{key}}}}}", value)
# Replace threat summary
html_template = html_template.replace("{{threat_summary}}", threat_summary_text)
# Write HTML file
with open(output_path, 'w') as f:
f.write(html_template)
print(f"Interactive comparison saved to: {output_path}")
return output_path
def process_and_visualize(self, image1_path, image2_path, output_dir, model_path=None, threshold=30, min_area=100):
"""
Process an image pair and create comprehensive visualization
Args:
image1_path: Path to first image
image2_path: Path to second image
output_dir: Directory to save outputs
model_path: Path to AI model (optional)
threshold: Threshold for difference detection
min_area: Minimum area for region detection
Returns:
        Path to the generated interactive HTML comparison (the primary output)
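        Example (illustrative; file paths are placeholders):
            interface = ComparisonInterface()
            html_path = interface.process_and_visualize(
                'data/photo_original.png', 'data/photo_modified.png',
                'output/', threshold=30, min_area=100)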
"""
# Initialize components
detector = DeepfakeDetector(model_path)
labeler = ThreatLabeler()
heatmap_gen = HeatmapGenerator()
# Create output directory
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Get base filename for outputs
base_name = os.path.splitext(os.path.basename(image1_path))[0]
# Step 1: Verification Module - Process the image pair
print(f"Processing images: {image1_path} and {image2_path}")
detection_results = detector.process_image_pair(image1_path, image2_path, threshold, min_area)
# Step 2: Labeling System - Label detected regions by threat level
original_image = self.img_processor.load_image(image1_path)
modified_image = self.img_processor.load_image(image2_path)
labeled_image, labeled_regions = labeler.label_regions(
original_image, detection_results['difference_image'], detection_results['bounding_boxes'])
# Get threat summary
threat_summary = labeler.get_threat_summary(labeled_regions)
# Step 3: Heatmap Visualization - Generate heatmaps for threat visualization
heatmap_overlay = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)
# Combine all results
all_results = {
'original_image': original_image,
'modified_image': modified_image,
'difference_image': detection_results['difference_image'],
'threshold_image': detection_results['threshold_image'],
'annotated_image': detection_results['annotated_image'],
'labeled_image': labeled_image,
'heatmap_overlay': heatmap_overlay,
'multi_heatmaps': multi_heatmaps,
'threat_summary': threat_summary,
'smi_score': detection_results['smi_score'],
'bounding_boxes': detection_results['bounding_boxes']
}
# Create and save comparison visualization
grid_output_path = os.path.join(output_dir, f"{base_name}_comparison.png")
self.create_comparison_grid(all_results, grid_output_path)
# Create interactive HTML comparison
html_output_path = os.path.join(output_dir, f"{base_name}_interactive.html")
self.create_interactive_comparison(all_results, html_output_path)
print(f"Comparison visualization saved to: {grid_output_path}")
print(f"Interactive HTML comparison saved to: {html_output_path}")
return html_output_path # Return the interactive HTML path as the primary output
def batch_process_directory(input_dir, output_dir, model_path=None, threshold=30, min_area=100):
"""
Process all image pairs in a directory and create comparison visualizations
Args:
input_dir: Directory containing input images
output_dir: Directory to save outputs
model_path: Path to AI model (optional)
threshold: Threshold for difference detection
min_area: Minimum area for region detection
Returns:
List of paths to generated HTML comparison files
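    Example (illustrative; directory names are placeholders):
        html_paths = batch_process_directory('input_images/', 'output/',
                                             threshold=30, min_area=100)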
"""
# Ensure output directory exists
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Get all image files in input directory
    # Sort so that the consecutive-pair fallback below is deterministic
    image_files = sorted(f for f in os.listdir(input_dir)
                         if f.lower().endswith(('.png', '.jpg', '.jpeg')))
if not image_files:
print(f"No image files found in {input_dir}")
return []
# Group images for comparison (assuming pairs with _original and _modified suffixes)
original_images = [f for f in image_files if '_original' in f]
modified_images = [f for f in image_files if '_modified' in f]
# If we don't have clear pairs, just process consecutive images
if not (original_images and modified_images):
# Process images in pairs (1&2, 3&4, etc.)
if len(image_files) < 2:
print("Need at least 2 images to compare")
return []
image_pairs = [(image_files[i], image_files[i+1])
for i in range(0, len(image_files)-1, 2)]
print(f"No _original/_modified naming pattern found. Processing {len(image_pairs)} consecutive pairs.")
else:
# Match original and modified pairs
image_pairs = []
        for orig in original_images:
            base_name = os.path.splitext(orig)[0].replace('_original', '')
            for mod in modified_images:
                # Pair "X_original.ext" with "X_modified.ext" on the shared base name
                if os.path.splitext(mod)[0].replace('_modified', '') == base_name:
                    image_pairs.append((orig, mod))
                    break
print(f"Found {len(image_pairs)} original/modified image pairs.")
if not image_pairs:
print("No valid image pairs found to process")
return []
# Initialize comparison interface
interface = ComparisonInterface()
# Process each image pair
html_paths = []
for img1, img2 in image_pairs:
img1_path = os.path.join(input_dir, img1)
img2_path = os.path.join(input_dir, img2)
print(f"\n{'='*50}")
print(f"Processing pair: {img1} and {img2}")
print(f"{'='*50}")
# Process and create comparison visualization
html_path = interface.process_and_visualize(
img1_path, img2_path, output_dir,
model_path, threshold, min_area
)
html_paths.append(html_path)
print(f"\n{'='*50}")
print(f"Overall Summary: Processed {len(image_pairs)} image pairs")
print(f"{'='*50}")
print(f"All comparison visualizations saved to: {output_dir}")
return html_paths
if __name__ == "__main__":
import argparse
import webbrowser
parser = argparse.ArgumentParser(description='Deepfake Detection Comparison Interface')
parser.add_argument('--input_dir', type=str, required=True, help='Directory containing input images')
parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output visualizations')
parser.add_argument('--model_path', type=str, help='Path to Nvidia AI model (optional)')
parser.add_argument('--threshold', type=int, default=30, help='Threshold for difference detection (0-255)')
parser.add_argument('--min_area', type=int, default=100, help='Minimum area for region detection')
parser.add_argument('--open_browser', action='store_true', help='Automatically open HTML results in browser')
args = parser.parse_args()
# Process all images in the directory
html_paths = batch_process_directory(
args.input_dir, args.output_dir,
args.model_path, args.threshold, args.min_area
)
# Open the first result in browser if requested
if args.open_browser and html_paths:
print(f"\nOpening first result in web browser: {html_paths[0]}")
webbrowser.open('file://' + os.path.abspath(html_paths[0]))
print("\nTo view interactive results, open the HTML files in your web browser.")
print("Example: file://" + os.path.abspath(html_paths[0]) if html_paths else "")
print("\nDone!")