"""Command-line driver: compare image pairs and save verification, labeling and heatmap outputs."""

import os
import argparse

import cv2
import numpy as np
import matplotlib.pyplot as plt

from deepfake_detector import DeepfakeDetector
from image_processor import ImageProcessor
from labeling import ThreatLabeler
from heatmap_generator import HeatmapGenerator

# File extensions (lower-case) that are treated as input images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Substrings that mark the two members of a named image pair.
_ORIGINAL_MARKER = '_original'
_MODIFIED_MARKER = '_modified'


def parse_args():
    """Parse and return the command-line arguments of the script."""
    parser = argparse.ArgumentParser(description='Deepfake Detection with Heatmap Visualization')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory containing input images')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output visualizations')
    parser.add_argument('--model_path', type=str, help='Path to Nvidia AI model (optional)')
    parser.add_argument('--threshold', type=int, default=30, help='Threshold for difference detection (0-255)')
    parser.add_argument('--min_area', type=int, default=100, help='Minimum area for region detection')
    return parser.parse_args()


def ensure_dir(directory):
    """Create ``directory`` (including parents) if it does not exist yet."""
    # exist_ok avoids the check-then-create race of an explicit os.path.exists test.
    os.makedirs(directory, exist_ok=True)


def process_image_pair(image1_path, image2_path, output_dir, model_path=None, threshold=30, min_area=100):
    """Run verification, labeling and heatmap generation for one image pair.

    Outputs are written below ``output_dir`` in the ``verification``,
    ``labeling`` and ``heatmap`` subdirectories, plus one side-by-side
    visualization directly in ``output_dir``. File names are derived from
    the base name of ``image1_path``.

    Args:
        image1_path: Path of the first image; it is also loaded as the
            "original" image for labeling and heatmaps.
        image2_path: Path of the second image of the pair.
        output_dir: Root directory for all generated files.
        model_path: Optional model path forwarded to ``DeepfakeDetector``.
        threshold: Difference threshold forwarded to the detector.
        min_area: Minimum region area forwarded to the detector.

    Returns:
        The threat summary returned by ``ThreatLabeler.get_threat_summary``.
    """
    # Initialize components
    detector = DeepfakeDetector(model_path)
    labeler = ThreatLabeler()
    heatmap_gen = HeatmapGenerator()
    img_processor = ImageProcessor()

    # Create output subdirectories
    verification_dir = os.path.join(output_dir, 'verification')
    labeling_dir = os.path.join(output_dir, 'labeling')
    heatmap_dir = os.path.join(output_dir, 'heatmap')
    for subdir in (verification_dir, labeling_dir, heatmap_dir):
        ensure_dir(subdir)

    # Get base filename for outputs
    base_name = os.path.splitext(os.path.basename(image1_path))[0]

    # Step 1: Verification Module - Process the image pair
    print(f"Processing images: {image1_path} and {image2_path}")
    results = detector.process_image_pair(image1_path, image2_path, threshold, min_area)

    # Save verification results
    verification_outputs = (
        ('difference_image', 'diff'),
        ('threshold_image', 'threshold'),
        ('annotated_image', 'annotated'),
    )
    for result_key, suffix in verification_outputs:
        img_processor.save_image(results[result_key],
                                 os.path.join(verification_dir, f"{base_name}_{suffix}.png"))

    # Print SMI score
    print(f"SMI Score: {results['smi_score']:.4f} (1.0 = identical, 0.0 = completely different)")

    # Step 2: Labeling System - Label detected regions by threat level
    original_image = img_processor.load_image(image1_path)
    labeled_image, labeled_regions = labeler.label_regions(
        original_image, results['difference_image'], results['bounding_boxes'])

    # Save labeled image
    img_processor.save_image(labeled_image, os.path.join(labeling_dir, f"{base_name}_labeled.png"))

    # Get threat summary
    threat_summary = labeler.get_threat_summary(labeled_regions)
    threat_counts = threat_summary['threat_counts']
    print("\nThreat Summary:")
    print(f"Total regions detected: {threat_summary['total_regions']}")
    print(f"Threat counts: Low={threat_counts['low']}, "
          f"Medium={threat_counts['medium']}, "
          f"High={threat_counts['high']}")
    if threat_summary['max_threat']:
        print(f"Maximum threat: {threat_summary['max_threat']['level'].upper()} "
              f"({threat_summary['max_threat']['percentage']:.1f}%)")
    print(f"Average difference: {threat_summary['average_difference']:.1f}%")

    # Step 3: Heatmap Visualization - Generate heatmaps for threat visualization
    # Generate standard heatmap
    heatmap = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
    img_processor.save_image(heatmap, os.path.join(heatmap_dir, f"{base_name}_heatmap.png"))

    # Generate multi-level heatmaps
    multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)

    # Save multi-level heatmaps. 'overlay' is not written here; it is used for the
    # side-by-side visualization below.
    # NOTE(review): the original comment said 'overlay' duplicates the standard
    # heatmap saved above - confirm against HeatmapGenerator.
    for level, hmap in multi_heatmaps.items():
        if level != 'overlay':
            img_processor.save_image(hmap, os.path.join(heatmap_dir, f"{base_name}_heatmap_{level}.png"))

    # Save side-by-side visualization
    heatmap_gen.save_heatmap_visualization(
        original_image, multi_heatmaps['overlay'],
        os.path.join(output_dir, f"{base_name}_visualization.png")
    )

    print(f"\nProcessing complete. Results saved to {output_dir}")
    return threat_summary


def _pair_key(filename, marker):
    """Return ``filename`` without its extension and without ``marker``."""
    stem = os.path.splitext(filename)[0]
    return stem.replace(marker, '')


def _pair_images(image_files):
    """Group file names into ``(first, second)`` pairs for comparison.

    When both ``_original`` and ``_modified`` files are present, each original
    is paired with the modified file that has the same name once the marker
    and the extension are removed (``face_original.png`` pairs with
    ``face_modified.jpg``). Otherwise the files are paired consecutively
    (1&2, 3&4, ...) in the order given.
    """
    original_images = [f for f in image_files if _ORIGINAL_MARKER in f]
    modified_images = [f for f in image_files if _MODIFIED_MARKER in f]

    # If we don't have clear pairs, just process consecutive images
    if not (original_images and modified_images):
        return list(zip(image_files[::2], image_files[1::2]))

    # First modified file wins when several share the same key.
    modified_by_key = {}
    for mod in modified_images:
        modified_by_key.setdefault(_pair_key(mod, _MODIFIED_MARKER), mod)

    pairs = []
    for orig in original_images:
        match = modified_by_key.get(_pair_key(orig, _ORIGINAL_MARKER))
        if match is None:
            # Fall back to the previous substring rule so naming schemes that
            # used to match keep matching.
            legacy_base = orig.replace(_ORIGINAL_MARKER, '')
            match = next((mod for mod in modified_images if legacy_base in mod), None)
        if match is not None:
            pairs.append((orig, match))
    return pairs


def main():
    """Process every image pair found in the input directory and print a summary."""
    args = parse_args()

    # Ensure output directory exists
    ensure_dir(args.output_dir)

    # Get all image files in input directory. os.listdir returns names in
    # arbitrary order, so sort to make the pairing deterministic.
    image_files = sorted(
        f for f in os.listdir(args.input_dir)
        if f.lower().endswith(IMAGE_EXTENSIONS)
    )

    # Group images for comparison (pairs with _original and _modified suffixes,
    # or consecutive files when no such naming is present).
    image_pairs = _pair_images(image_files)

    # Report images that will not be processed instead of dropping them silently.
    paired_names = {name for pair in image_pairs for name in pair}
    unpaired = [f for f in image_files if f not in paired_names]
    if unpaired:
        print(f"Skipping {len(unpaired)} unpaired image(s): {', '.join(unpaired)}")

    # Process each image pair
    results = []
    separator = '=' * 50
    for img1, img2 in image_pairs:
        img1_path = os.path.join(args.input_dir, img1)
        img2_path = os.path.join(args.input_dir, img2)

        print(f"\n{separator}")
        print(f"Processing pair: {img1} and {img2}")
        print(f"{separator}")

        result = process_image_pair(
            img1_path, img2_path, args.output_dir,
            args.model_path, args.threshold, args.min_area
        )
        results.append({
            'pair': (img1, img2),
            'summary': result
        })

    # Print overall summary
    print(f"\n{separator}")
    print(f"Overall Summary: Processed {len(image_pairs)} image pairs")
    print(f"{separator}")

    high_threat_pairs = [r for r in results if r['summary']['threat_counts']['high'] > 0]
    print(f"Pairs with high threats: {len(high_threat_pairs)} / {len(results)}")

    if high_threat_pairs:
        print("\nHigh threat pairs:")
        for r in high_threat_pairs:
            print(f"- {r['pair'][0]} and {r['pair'][1]}: "
                  f"{r['summary']['threat_counts']['high']} high threat regions")


if __name__ == "__main__":
    main()