Spaces:
Sleeping
Sleeping
import os | |
import argparse | |
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from deepfake_detector import DeepfakeDetector | |
from image_processor import ImageProcessor | |
from labeling import ThreatLabeler | |
from heatmap_generator import HeatmapGenerator | |
def parse_args(argv=None):
    """Parse and validate the command-line options for the detection pipeline.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for programmatic use and testing).

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``, ``model_path``,
        ``threshold`` and ``min_area`` attributes.

    Exits via ``parser.error`` (status 2) when ``--threshold`` lies outside
    0-255 or ``--min_area`` is negative.
    """
    parser = argparse.ArgumentParser(description='Deepfake Detection with Heatmap Visualization')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory containing input images')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output visualizations')
    parser.add_argument('--model_path', type=str, help='Path to Nvidia AI model (optional)')
    parser.add_argument('--threshold', type=int, default=30, help='Threshold for difference detection (0-255)')
    parser.add_argument('--min_area', type=int, default=100, help='Minimum area for region detection')
    args = parser.parse_args(argv)

    # The help text promises an 8-bit range; reject anything outside it here
    # instead of passing an out-of-range value on to the detector.
    if not 0 <= args.threshold <= 255:
        parser.error(f'--threshold must be in the range 0-255 (got {args.threshold})')
    if args.min_area < 0:
        parser.error(f'--min_area must be non-negative (got {args.min_area})')
    return args
def ensure_dir(directory):
    """Create *directory* (and any missing parents) if it does not already exist.

    ``exist_ok=True`` replaces a separate existence check, so a directory
    created concurrently between a check and the ``makedirs`` call does not
    raise. If *directory* exists but is not a directory, ``FileExistsError``
    is raised so the problem surfaces here rather than when outputs are
    later written into it.
    """
    os.makedirs(directory, exist_ok=True)
def _save_verification_outputs(img_processor, results, verification_dir, base_name):
    """Write the detector's difference, threshold and annotated images."""
    for key, suffix in (('difference_image', 'diff'),
                        ('threshold_image', 'threshold'),
                        ('annotated_image', 'annotated')):
        img_processor.save_image(results[key],
                                 os.path.join(verification_dir, f"{base_name}_{suffix}.png"))


def _print_threat_summary(threat_summary):
    """Print the summary dict returned by ``ThreatLabeler.get_threat_summary``."""
    counts = threat_summary['threat_counts']
    print("\nThreat Summary:")
    print(f"Total regions detected: {threat_summary['total_regions']}")
    print(f"Threat counts: Low={counts['low']}, "
          f"Medium={counts['medium']}, "
          f"High={counts['high']}")
    max_threat = threat_summary['max_threat']
    if max_threat:
        print(f"Maximum threat: {max_threat['level'].upper()} "
              f"({max_threat['percentage']:.1f}%)")
    print(f"Average difference: {threat_summary['average_difference']:.1f}%")


def _save_heatmaps(heatmap_gen, img_processor, original_image, labeled_regions,
                   heatmap_dir, output_dir, base_name):
    """Generate and save the threat heatmaps plus the side-by-side visualization."""
    # Standalone heatmap from generate_threat_heatmap.
    heatmap = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
    img_processor.save_image(heatmap,
                             os.path.join(heatmap_dir, f"{base_name}_heatmap.png"))

    multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)
    for level, hmap in multi_heatmaps.items():
        # 'overlay' is not written to heatmap_dir; it only feeds the side-by-side
        # visualization below. NOTE(review): the standalone heatmap saved above
        # comes from a different call — confirm it is equivalent to this overlay.
        if level == 'overlay':
            continue
        img_processor.save_image(hmap,
                                 os.path.join(heatmap_dir, f"{base_name}_heatmap_{level}.png"))

    heatmap_gen.save_heatmap_visualization(
        original_image, multi_heatmaps['overlay'],
        os.path.join(output_dir, f"{base_name}_visualization.png"),
    )


def process_image_pair(image1_path, image2_path, output_dir, model_path=None, threshold=30, min_area=100):
    """Run verification, threat labeling and heatmap generation for one image pair.

    Outputs are written under *output_dir*:
      - ``verification/``: difference, threshold and annotated images;
      - ``labeling/``: the labeled image;
      - ``heatmap/``: the threat heatmaps;
      - ``<base>_visualization.png`` at the top level,
    where ``<base>`` is *image1_path*'s file name without its extension.

    Args:
        image1_path: Path to the first image; labels and heatmaps are drawn on it.
        image2_path: Path to the image compared against *image1_path*.
        output_dir: Root directory for all outputs.
        model_path: Optional model path forwarded to ``DeepfakeDetector``.
        threshold: Difference threshold forwarded to the detector.
        min_area: Minimum region area forwarded to the detector.

    Returns:
        The threat-summary dict produced by ``ThreatLabeler.get_threat_summary``.
    """
    # Components are constructed afresh on every call.
    detector = DeepfakeDetector(model_path)
    labeler = ThreatLabeler()
    heatmap_gen = HeatmapGenerator()
    img_processor = ImageProcessor()

    verification_dir = os.path.join(output_dir, 'verification')
    labeling_dir = os.path.join(output_dir, 'labeling')
    heatmap_dir = os.path.join(output_dir, 'heatmap')
    for subdir in (verification_dir, labeling_dir, heatmap_dir):
        ensure_dir(subdir)

    # Output names are keyed on the first image's stem only, so two pairs whose
    # first images share a stem overwrite each other's outputs.
    base_name = os.path.splitext(os.path.basename(image1_path))[0]

    # Step 1: verification — compare the two images.
    print(f"Processing images: {image1_path} and {image2_path}")
    results = detector.process_image_pair(image1_path, image2_path, threshold, min_area)
    _save_verification_outputs(img_processor, results, verification_dir, base_name)
    print(f"SMI Score: {results['smi_score']:.4f} (1.0 = identical, 0.0 = completely different)")

    # Step 2: labeling — label detected regions by threat level on the first image.
    original_image = img_processor.load_image(image1_path)
    labeled_image, labeled_regions = labeler.label_regions(
        original_image, results['difference_image'], results['bounding_boxes'])
    img_processor.save_image(labeled_image,
                             os.path.join(labeling_dir, f"{base_name}_labeled.png"))
    threat_summary = labeler.get_threat_summary(labeled_regions)
    _print_threat_summary(threat_summary)

    # Step 3: heatmap visualization.
    _save_heatmaps(heatmap_gen, img_processor, original_image, labeled_regions,
                   heatmap_dir, output_dir, base_name)

    print(f"\nProcessing complete. Results saved to {output_dir}")
    return threat_summary
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _list_image_files(input_dir):
    """Return the sorted names of regular files in *input_dir* with an image extension."""
    # os.listdir order is arbitrary; sorting makes the pairing reproducible.
    return sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(input_dir, name))
    )


def _pair_key(filename, marker):
    """Return *filename*'s stem with *marker* removed; equal keys form a pair."""
    return os.path.splitext(filename)[0].replace(marker, '')


def _pair_images(image_files):
    """Group image file names into (first, second) comparison pairs.

    When both ``_original`` and ``_modified`` files are present, each original
    is paired with the first modified file (in *image_files* order) whose stem,
    with the marker removed, is identical — e.g. ``face_original.png`` pairs
    with ``face_modified.jpg``. Otherwise files are paired consecutively
    (1&2, 3&4, ...) and an odd trailing file is skipped with a warning.
    """
    original_images = [f for f in image_files if '_original' in f]
    modified_images = [f for f in image_files if '_modified' in f]

    if not (original_images and modified_images):
        if len(image_files) % 2:
            print(f"Warning: odd number of images; {image_files[-1]} has no partner and is skipped")
        return list(zip(image_files[0::2], image_files[1::2]))

    # Compare extension-less stems: comparing whole names would keep the
    # extension in the original's key and never match the suffix convention.
    modified_by_key = {}
    for mod in modified_images:
        modified_by_key.setdefault(_pair_key(mod, '_modified'), mod)  # first match wins

    image_pairs = []
    for orig in original_images:
        mod = modified_by_key.get(_pair_key(orig, '_original'))
        if mod is None:
            print(f"Warning: no '_modified' counterpart found for {orig}; skipping")
        else:
            image_pairs.append((orig, mod))
    return image_pairs


def _print_overall_summary(results):
    """Print how many pairs were processed and which contain high-threat regions."""
    print(f"\n{'='*50}")
    print(f"Overall Summary: Processed {len(results)} image pairs")
    print(f"{'='*50}")
    high_threat_pairs = [r for r in results
                         if r['summary']['threat_counts']['high'] > 0]
    print(f"Pairs with high threats: {len(high_threat_pairs)} / {len(results)}")
    if high_threat_pairs:
        print("\nHigh threat pairs:")
        for r in high_threat_pairs:
            first, second = r['pair']
            print(f"- {first} and {second}: "
                  f"{r['summary']['threat_counts']['high']} high threat regions")


def main():
    """Command-line entry point: pair the images in --input_dir and analyse each pair.

    Each pair is handled by ``process_image_pair``. An overall summary follows,
    listing the pairs that contain at least one high-threat region.
    """
    args = parse_args()
    ensure_dir(args.output_dir)

    image_pairs = _pair_images(_list_image_files(args.input_dir))

    results = []
    for img1, img2 in image_pairs:
        print(f"\n{'='*50}")
        print(f"Processing pair: {img1} and {img2}")
        print(f"{'='*50}")
        summary = process_image_pair(
            os.path.join(args.input_dir, img1),
            os.path.join(args.input_dir, img2),
            args.output_dir, args.model_path, args.threshold, args.min_area,
        )
        results.append({'pair': (img1, img2), 'summary': summary})

    _print_overall_summary(results)


if __name__ == "__main__":
    main()