# heatmap / compare_images.py
# noumanjavaid's picture
# Upload 11 files
# 9e629a3 verified
import os
import sys
from deepfake_detector import DeepfakeDetector
from image_processor import ImageProcessor
from labeling import ThreatLabeler
from heatmap_generator import HeatmapGenerator
from comparison_interface import ComparisonInterface
def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    The creation is done with ``os.makedirs(..., exist_ok=True)`` rather than
    a separate ``os.path.exists`` check, so a directory created by another
    process between the check and the creation no longer raises
    ``FileExistsError``.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    os.makedirs(directory, exist_ok=True)
def compare_images(image1_path, image2_path, output_dir, threshold=30, min_area=100):
    """Run the comparison pipeline on two images and save a summary grid.

    The pair is handed to ``DeepfakeDetector.process_image_pair``; the
    detected regions are labeled by ``ThreatLabeler``; heatmaps are produced
    by ``HeatmapGenerator``; and all intermediate results are passed to
    ``ComparisonInterface.create_comparison_grid`` together with the output
    path. A text summary is printed to stdout.

    Args:
        image1_path: Path to the first image (loaded as the "original"). Its
            base name, without extension, is used to name the output file.
        image2_path: Path to the second image (loaded as the "modified").
        output_dir: Directory for the output file; created via ``ensure_dir``.
        threshold: Forwarded unchanged to ``process_image_pair``.
        min_area: Forwarded unchanged to ``process_image_pair``.

    Returns:
        The path ``<output_dir>/<image1 base name>_comparison.png`` that was
        passed to ``create_comparison_grid``.
    """
    # Build every pipeline stage up front.
    pair_detector = DeepfakeDetector()
    threat_labeler = ThreatLabeler()
    heatmap_builder = HeatmapGenerator()
    loader = ImageProcessor()
    grid_builder = ComparisonInterface()

    ensure_dir(output_dir)

    # The output file is named after the first image.
    stem, _extension = os.path.splitext(os.path.basename(image1_path))
    print(f"Processing images: {image1_path} and {image2_path}")

    # Step 1: verification - analyze the image pair.
    detection = pair_detector.process_image_pair(
        image1_path, image2_path, threshold, min_area
    )

    # Step 2: labeling - tag each detected region with a threat level.
    original_image = loader.load_image(image1_path)
    modified_image = loader.load_image(image2_path)
    diff_image = detection['difference_image']
    boxes = detection['bounding_boxes']
    labeled_image, regions = threat_labeler.label_regions(
        original_image, diff_image, boxes
    )
    summary = threat_labeler.get_threat_summary(regions)

    # Step 3: heatmaps - one combined overlay plus the multi-level variant.
    overlay = heatmap_builder.generate_threat_heatmap(original_image, regions)
    per_level = heatmap_builder.generate_multi_level_heatmap(original_image, regions)

    # Bundle everything the comparison grid receives.
    grid_inputs = dict(
        original_image=original_image,
        modified_image=modified_image,
        difference_image=diff_image,
        threshold_image=detection['threshold_image'],
        annotated_image=detection['annotated_image'],
        labeled_image=labeled_image,
        heatmap_overlay=overlay,
        multi_heatmaps=per_level,
        threat_summary=summary,
        smi_score=detection['smi_score'],
        bounding_boxes=boxes,
    )

    output_path = os.path.join(output_dir, f"{stem}_comparison.png")
    grid_builder.create_comparison_grid(grid_inputs, output_path)

    # Console report.
    print(f"\nAnalysis Results:")
    print(f"SMI Score: {detection['smi_score']:.4f} (1.0 = identical, 0.0 = completely different)")
    print(f"Total regions detected: {summary['total_regions']}")
    counts = summary['threat_counts']
    print(f"Threat counts: Low={counts['low']}, Medium={counts['medium']}, High={counts['high']}")
    worst = summary['max_threat']
    if worst:
        # Only reported when the summary carries a max-threat entry.
        print(f"Maximum threat: {worst['level'].upper()} ({worst['percentage']:.1f}%)")
    print(f"Average difference: {summary['average_difference']:.1f}%")
    print(f"\nComparison visualization saved to: {output_path}")

    return output_path
def main():
    """Command-line entry point: compare two images named in ``sys.argv``.

    Expected arguments: ``<image1_path> <image2_path> [output_dir]``.
    With fewer than two path arguments a usage line is printed and the
    process exits with status 1. When ``output_dir`` is omitted,
    ``./comparison_output`` is used.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python compare_images.py <image1_path> <image2_path> [output_dir]")
        sys.exit(1)

    first_image, second_image = cli_args[:2]
    # An optional third argument overrides the default output location.
    target_dir = cli_args[2] if len(cli_args) > 2 else "./comparison_output"

    compare_images(first_image, second_image, target_dir)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()