heatmap / main.py
noumanjavaid's picture
Upload 11 files
9e629a3 verified
raw
history blame
7.15 kB
import os
import argparse
import cv2
import numpy as np
import matplotlib.pyplot as plt
from deepfake_detector import DeepfakeDetector
from image_processor import ImageProcessor
from labeling import ThreatLabeler
from heatmap_generator import HeatmapGenerator
def _int_in_range(low, high=None):
    """Return an argparse ``type`` callable that accepts ints in [low, high] (high optional)."""
    def convert(text):
        value = int(text)  # a ValueError here makes argparse report "invalid int value"
        if value < low or (high is not None and value > high):
            bound = f"{low}-{high}" if high is not None else f">= {low}"
            raise argparse.ArgumentTypeError(f"{value} is out of range ({bound})")
        return value
    # argparse uses the callable's __name__ in its "invalid <name> value" message.
    convert.__name__ = 'int'
    return convert


def parse_args(argv=None):
    """Parse the command-line options for the heatmap pipeline.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``, ``model_path``,
        ``threshold`` (validated to 0-255) and ``min_area`` (validated >= 0).

    Raises:
        SystemExit: on missing required options or out-of-range values
        (argparse's standard error handling).
    """
    parser = argparse.ArgumentParser(description='Deepfake Detection with Heatmap Visualization')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory containing input images')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output visualizations')
    parser.add_argument('--model_path', type=str, help='Path to Nvidia AI model (optional)')
    parser.add_argument('--threshold', type=_int_in_range(0, 255), default=30,
                        help='Threshold for difference detection (0-255)')
    parser.add_argument('--min_area', type=_int_in_range(0), default=100,
                        help='Minimum area for region detection')
    return parser.parse_args(argv)
def ensure_dir(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    ``exist_ok=True`` removes the check-then-create race of calling
    ``os.path.exists`` before ``os.makedirs``: if another process creates the
    directory in between, no FileExistsError is raised.

    An empty string denotes the current directory, which always exists, so it
    is accepted as a no-op instead of crashing ``os.makedirs``.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
def process_image_pair(image1_path, image2_path, output_dir, model_path=None, threshold=30, min_area=100):
    """Run verification, threat labeling and heatmap generation for one image pair.

    Artifacts written (``<stem>`` = file name of ``image1_path`` without extension):

    * ``<output_dir>/verification/<stem>_diff.png``, ``_threshold.png``,
      ``_annotated.png`` -- taken from the detector results;
    * ``<output_dir>/labeling/<stem>_labeled.png``;
    * ``<output_dir>/heatmap/<stem>_heatmap.png`` plus one
      ``<stem>_heatmap_<level>.png`` per level returned by the multi-level
      generator, except ``'overlay'``;
    * ``<output_dir>/<stem>_visualization.png`` (side-by-side view).

    Args:
        image1_path: First image of the pair; it is also the image that is
            loaded and drawn on for labeling and heatmaps.
        image2_path: Second image of the pair.
        output_dir: Root directory for all outputs; subdirectories are created.
        model_path: Optional model path passed to ``DeepfakeDetector``.
        threshold: Difference threshold passed to the detector.
        min_area: Minimum region area passed to the detector.

    Returns:
        The dict produced by ``ThreatLabeler.get_threat_summary``.
    """
    # Pipeline components, constructed in the same order as always.
    detector = DeepfakeDetector(model_path)
    labeler = ThreatLabeler()
    heatmap_gen = HeatmapGenerator()
    img_processor = ImageProcessor()

    verification_dir, labeling_dir, heatmap_dir = (
        os.path.join(output_dir, sub) for sub in ('verification', 'labeling', 'heatmap')
    )
    for sub_dir in (verification_dir, labeling_dir, heatmap_dir):
        ensure_dir(sub_dir)

    stem = os.path.splitext(os.path.basename(image1_path))[0]

    def out_path(folder, tag):
        # Every artifact is named "<stem>_<tag>.png".
        return os.path.join(folder, f"{stem}_{tag}.png")

    # --- Step 1: verification -------------------------------------------------
    print(f"Processing images: {image1_path} and {image2_path}")
    results = detector.process_image_pair(image1_path, image2_path, threshold, min_area)
    for result_key, tag in (('difference_image', 'diff'),
                            ('threshold_image', 'threshold'),
                            ('annotated_image', 'annotated')):
        img_processor.save_image(results[result_key], out_path(verification_dir, tag))
    print(f"SMI Score: {results['smi_score']:.4f} (1.0 = identical, 0.0 = completely different)")

    # --- Step 2: label detected regions by threat level ----------------------
    original_image = img_processor.load_image(image1_path)
    labeled_image, labeled_regions = labeler.label_regions(
        original_image, results['difference_image'], results['bounding_boxes'])
    img_processor.save_image(labeled_image, out_path(labeling_dir, 'labeled'))

    threat_summary = labeler.get_threat_summary(labeled_regions)
    print("\nThreat Summary:")
    print(f"Total regions detected: {threat_summary['total_regions']}")
    counts = threat_summary['threat_counts']
    print(f"Threat counts: Low={counts['low']}, "
          f"Medium={counts['medium']}, "
          f"High={counts['high']}")
    worst = threat_summary['max_threat']
    if worst:
        print(f"Maximum threat: {worst['level'].upper()} ({worst['percentage']:.1f}%)")
    print(f"Average difference: {threat_summary['average_difference']:.1f}%")

    # --- Step 3: heatmaps -----------------------------------------------------
    combined = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
    img_processor.save_image(combined, out_path(heatmap_dir, 'heatmap'))

    per_level = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)
    for level, level_map in per_level.items():
        # 'overlay' gets no file of its own; it feeds the side-by-side
        # visualization below. NOTE(review): presumably it duplicates the
        # combined heatmap saved above -- confirm against HeatmapGenerator.
        if level == 'overlay':
            continue
        img_processor.save_image(level_map, out_path(heatmap_dir, f"heatmap_{level}"))

    heatmap_gen.save_heatmap_visualization(
        original_image, per_level['overlay'], out_path(output_dir, 'visualization'))

    print(f"\nProcessing complete. Results saved to {output_dir}")
    return threat_summary
def _pair_images(image_files):
    """Group image file names into ``(first, second)`` pairs for comparison.

    If the names include both ``_original`` and ``_modified`` files, each
    original is paired with the modified file whose extension-less name, with
    the marker removed, is identical (``cat_original.png`` <-> ``cat_modified.jpg``).
    Otherwise the files are sorted and paired consecutively (1&2, 3&4, ...).
    Files left without a partner are reported and skipped.

    Returns:
        list of ``(name1, name2)`` tuples, in deterministic (sorted) order.
    """
    # os.listdir order is arbitrary; sort so pairing is reproducible.
    files = sorted(image_files)
    originals = [f for f in files if '_original' in f]
    modified = [f for f in files if '_modified' in f]

    if not (originals and modified):
        if len(files) % 2:
            print(f"Warning: odd number of images; '{files[-1]}' has no partner and is skipped")
        return list(zip(files[0::2], files[1::2]))

    def match_key(name, marker):
        # Compare extension-less names so "cat_original.png" and
        # "cat_modified.png" both map to "cat" (the old substring test kept
        # ".png" in the key and so never matched).
        return os.path.splitext(name)[0].replace(marker, '')

    modified_by_key = {}
    for name in modified:
        modified_by_key.setdefault(match_key(name, '_modified'), name)  # first match wins

    pairs = []
    for name in originals:
        partner = modified_by_key.get(match_key(name, '_original'))
        if partner is None:
            print(f"Warning: no '_modified' counterpart found for '{name}'")
        else:
            pairs.append((name, partner))
    return pairs


def main():
    """CLI entry point.

    Collects the .png/.jpg/.jpeg images in ``--input_dir``, pairs them (see
    ``_pair_images``), runs ``process_image_pair`` on each pair and prints an
    overall summary of the pairs that contain high-threat regions.
    """
    args = parse_args()

    # Ensure output directory exists
    ensure_dir(args.output_dir)

    image_files = [f for f in os.listdir(args.input_dir)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    image_pairs = _pair_images(image_files)

    # Process each image pair
    results = []
    for img1, img2 in image_pairs:
        img1_path = os.path.join(args.input_dir, img1)
        img2_path = os.path.join(args.input_dir, img2)
        print(f"\n{'='*50}")
        print(f"Processing pair: {img1} and {img2}")
        print(f"{'='*50}")
        result = process_image_pair(
            img1_path, img2_path, args.output_dir,
            args.model_path, args.threshold, args.min_area
        )
        results.append({
            'pair': (img1, img2),
            'summary': result
        })

    # Print overall summary
    print(f"\n{'='*50}")
    print(f"Overall Summary: Processed {len(image_pairs)} image pairs")
    print(f"{'='*50}")
    high_threat_pairs = [r for r in results
                         if r['summary']['threat_counts']['high'] > 0]
    print(f"Pairs with high threats: {len(high_threat_pairs)} / {len(results)}")
    if high_threat_pairs:
        print("\nHigh threat pairs:")
        for r in high_threat_pairs:
            print(f"- {r['pair'][0]} and {r['pair'][1]}: "
                  f"{r['summary']['threat_counts']['high']} high threat regions")


if __name__ == "__main__":
    main()