Spaces:
Sleeping
Sleeping
File size: 7,151 Bytes
9e629a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import os
import argparse
import cv2
import numpy as np
import matplotlib.pyplot as plt
from deepfake_detector import DeepfakeDetector
from image_processor import ImageProcessor
from labeling import ThreatLabeler
from heatmap_generator import HeatmapGenerator
def parse_args(argv=None):
    """Parse the command-line options for the deepfake heatmap pipeline.

    Args:
        argv: Optional sequence of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing
            an explicit list lets other code or tests drive the parser.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``, ``model_path``
        (``None`` when omitted), ``threshold`` and ``min_area``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value is invalid, e.g. a ``--threshold`` outside 0-255.
    """

    def threshold_value(text):
        # The help text documents a 0-255 range; reject values outside it at
        # parse time with a clear message instead of handing them to the
        # detector unchecked.
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if not 0 <= value <= 255:
            raise argparse.ArgumentTypeError(f"threshold must be in 0-255, got {value}")
        return value

    parser = argparse.ArgumentParser(description='Deepfake Detection with Heatmap Visualization')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory containing input images')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output visualizations')
    parser.add_argument('--model_path', type=str, help='Path to Nvidia AI model (optional)')
    parser.add_argument('--threshold', type=threshold_value, default=30,
                        help='Threshold for difference detection (0-255)')
    parser.add_argument('--min_area', type=int, default=100, help='Minimum area for region detection')
    return parser.parse_args(argv)
def ensure_dir(directory):
    """Create *directory* (and any missing parents) if it does not already exist.

    ``exist_ok=True`` replaces the earlier separate existence check, so a
    directory created by someone else between the check and the
    ``makedirs`` call no longer raises ``FileExistsError``.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: if *directory* exists but is not a directory
            (previously this was silently accepted, causing later writes
            into it to fail).
        OSError: for other filesystem errors such as permission denied.
    """
    os.makedirs(directory, exist_ok=True)
def process_image_pair(image1_path, image2_path, output_dir, model_path=None, threshold=30, min_area=100):
    """Run verification, threat labeling and heatmap generation for one image pair.

    Files written under ``output_dir`` (``<stem>`` is the stem of ``image1_path``):
      * ``verification/<stem>_diff.png``, ``_threshold.png``, ``_annotated.png``
      * ``labeling/<stem>_labeled.png``
      * ``heatmap/<stem>_heatmap.png`` plus one ``_heatmap_<level>.png`` per
        non-overlay level returned by the multi-level generator
      * ``<stem>_visualization.png`` (side-by-side view) at the top level

    Args:
        image1_path: Path of the first image; it is also the image labeled and
            used as the heatmap background.
        image2_path: Path of the image compared against the first.
        output_dir: Root directory for all outputs.
        model_path: Optional model path passed to ``DeepfakeDetector``.
        threshold: Difference threshold passed to the detector.
        min_area: Minimum region area passed to the detector.

    Returns:
        The summary returned by ``ThreatLabeler.get_threat_summary``.
    """
    # Pipeline components, constructed in their original order.
    detector = DeepfakeDetector(model_path)
    labeler = ThreatLabeler()
    heatmap_gen = HeatmapGenerator()
    img_processor = ImageProcessor()

    # One output subdirectory per pipeline stage.
    stage_dirs = {
        stage: os.path.join(output_dir, stage)
        for stage in ('verification', 'labeling', 'heatmap')
    }
    for stage_dir in stage_dirs.values():
        ensure_dir(stage_dir)

    stem = os.path.splitext(os.path.basename(image1_path))[0]

    # --- Stage 1: verification -------------------------------------------
    print(f"Processing images: {image1_path} and {image2_path}")
    results = detector.process_image_pair(image1_path, image2_path, threshold, min_area)

    verification_outputs = (
        ('difference_image', f"{stem}_diff.png"),
        ('threshold_image', f"{stem}_threshold.png"),
        ('annotated_image', f"{stem}_annotated.png"),
    )
    for result_key, filename in verification_outputs:
        img_processor.save_image(results[result_key],
                                 os.path.join(stage_dirs['verification'], filename))

    print(f"SMI Score: {results['smi_score']:.4f} (1.0 = identical, 0.0 = completely different)")

    # --- Stage 2: threat labeling ------------------------------------------
    original_image = img_processor.load_image(image1_path)
    labeled_image, labeled_regions = labeler.label_regions(
        original_image, results['difference_image'], results['bounding_boxes'])
    img_processor.save_image(labeled_image,
                             os.path.join(stage_dirs['labeling'], f"{stem}_labeled.png"))

    summary = labeler.get_threat_summary(labeled_regions)
    print("\nThreat Summary:")
    print(f"Total regions detected: {summary['total_regions']}")
    counts = summary['threat_counts']
    print(f"Threat counts: Low={counts['low']}, Medium={counts['medium']}, High={counts['high']}")
    top_threat = summary['max_threat']
    if top_threat:
        print(f"Maximum threat: {top_threat['level'].upper()} ({top_threat['percentage']:.1f}%)")
    print(f"Average difference: {summary['average_difference']:.1f}%")

    # --- Stage 3: heatmaps -------------------------------------------------
    heatmap = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
    img_processor.save_image(heatmap,
                             os.path.join(stage_dirs['heatmap'], f"{stem}_heatmap.png"))

    multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)
    for level, hmap in multi_heatmaps.items():
        # NOTE(review): 'overlay' is skipped on the assumption that it matches
        # the heatmap saved just above (generate_threat_heatmap output);
        # confirm the two are actually the same image.
        if level == 'overlay':
            continue
        img_processor.save_image(hmap,
                                 os.path.join(stage_dirs['heatmap'], f"{stem}_heatmap_{level}.png"))

    heatmap_gen.save_heatmap_visualization(
        original_image,
        multi_heatmaps['overlay'],
        os.path.join(output_dir, f"{stem}_visualization.png"),
    )

    print(f"\nProcessing complete. Results saved to {output_dir}")
    return summary
def _pair_key(filename, marker):
    """Return *filename*'s stem with *marker* removed; used to match image pairs."""
    return os.path.splitext(filename)[0].replace(marker, '')


def _find_image_pairs(image_files):
    """Group (already sorted) image filenames into (first, second) comparison pairs.

    If both ``_original`` and ``_modified`` files are present, pairs are
    matched by marker-free stem; otherwise consecutive files are paired
    (1&2, 3&4, ...).
    """
    original_images = [f for f in image_files if '_original' in f]
    modified_images = [f for f in image_files if '_modified' in f]

    if not (original_images and modified_images):
        # No clear naming convention: pair consecutive images. zip() drops a
        # trailing odd image, so report it instead of skipping it silently.
        if len(image_files) % 2:
            print(f"Warning: odd number of images; {image_files[-1]} has no partner and is skipped")
        return list(zip(image_files[0::2], image_files[1::2]))

    # Index modified images by their marker-free stem. The previous substring
    # test ('face.png' in 'face_modified.png') could never match a normal
    # "<name>_original.<ext>" / "<name>_modified.<ext>" pair, because the
    # extension follows the marker. setdefault keeps the first candidate in
    # sorted order, mirroring the old first-match-wins ``break``.
    modified_by_key = {}
    for mod in modified_images:
        modified_by_key.setdefault(_pair_key(mod, '_modified'), mod)

    pairs = []
    for orig in original_images:
        mod = modified_by_key.get(_pair_key(orig, '_original'))
        if mod is None:
            print(f"Warning: no '_modified' counterpart for {orig}; skipping")
            continue
        pairs.append((orig, mod))
    return pairs


def main():
    """Pair up the images in ``--input_dir`` and run the pipeline on each pair.

    Pairing strategy:
      * When the directory holds files whose names contain ``_original`` and
        files whose names contain ``_modified``, each ``_original`` file is
        paired with the ``_modified`` file whose stem is identical once the
        marker is removed (e.g. ``face_original.png`` with ``face_modified.png``).
      * Otherwise images are paired consecutively in sorted filename order.

    Per-pair results are written by ``process_image_pair``. A summary of
    pairs containing high-threat regions is printed at the end.
    """
    args = parse_args()

    ensure_dir(args.output_dir)

    # os.listdir() returns entries in arbitrary order; sort so consecutive
    # pairing and first-match selection are deterministic across runs.
    image_files = sorted(
        f for f in os.listdir(args.input_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    image_pairs = _find_image_pairs(image_files)
    if not image_pairs:
        print(f"No image pairs found in {args.input_dir}")
        return

    results = []
    for img1, img2 in image_pairs:
        img1_path = os.path.join(args.input_dir, img1)
        img2_path = os.path.join(args.input_dir, img2)

        print(f"\n{'='*50}")
        print(f"Processing pair: {img1} and {img2}")
        print(f"{'='*50}")

        result = process_image_pair(
            img1_path, img2_path, args.output_dir,
            args.model_path, args.threshold, args.min_area
        )
        results.append({
            'pair': (img1, img2),
            'summary': result
        })

    # Overall summary across all processed pairs.
    print(f"\n{'='*50}")
    print(f"Overall Summary: Processed {len(image_pairs)} image pairs")
    print(f"{'='*50}")

    high_threat_pairs = [r for r in results
                         if r['summary']['threat_counts']['high'] > 0]
    print(f"Pairs with high threats: {len(high_threat_pairs)} / {len(results)}")

    if high_threat_pairs:
        print("\nHigh threat pairs:")
        for r in high_threat_pairs:
            print(f"- {r['pair'][0]} and {r['pair'][1]}: "
                  f"{r['summary']['threat_counts']['high']} high threat regions")
# Script entry point. The stray " |" page residue that followed main() made
# this line a syntax error; it has been removed.
if __name__ == "__main__":
    main()