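"""Command-line entry point for the deepfake detection pipeline.

Runs three stages on each image pair: verification (difference detection and
SMI scoring), threat labeling of the detected regions, and heatmap
visualization. See the step comments in process_image_pair() below.
"""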
import os
import argparse
from deepfake_detector import DeepfakeDetector
from image_processor import ImageProcessor
from labeling import ThreatLabeler
from heatmap_generator import HeatmapGenerator

def parse_args():
    parser = argparse.ArgumentParser(description='Deepfake Detection with Heatmap Visualization')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory containing input images')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output visualizations')
    parser.add_argument('--model_path', type=str, help='Path to Nvidia AI model (optional)')
    parser.add_argument('--threshold', type=int, default=30, help='Threshold for difference detection (0-255)')
    parser.add_argument('--min_area', type=int, default=100, help='Minimum area for region detection')
    return parser.parse_args()
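
# Example invocation (script name and paths are illustrative):
#   python main.py --input_dir ./samples --output_dir ./results \
#                  --threshold 30 --min_area 100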

def ensure_dir(directory):
    # Create the directory (and any missing parents) if it does not already exist
    os.makedirs(directory, exist_ok=True)

def process_image_pair(image1_path, image2_path, output_dir, model_path=None, threshold=30, min_area=100):
    # Initialize components
    detector = DeepfakeDetector(model_path)
    labeler = ThreatLabeler()
    heatmap_gen = HeatmapGenerator()
    img_processor = ImageProcessor()
    
    # Create output subdirectories
    verification_dir = os.path.join(output_dir, 'verification')
    labeling_dir = os.path.join(output_dir, 'labeling')
    heatmap_dir = os.path.join(output_dir, 'heatmap')
    ensure_dir(verification_dir)
    ensure_dir(labeling_dir)
    ensure_dir(heatmap_dir)
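    # Resulting layout under output_dir (filenames shown are illustrative):
    #   verification/<name>_diff.png, <name>_threshold.png, <name>_annotated.png
    #   labeling/<name>_labeled.png
    #   heatmap/<name>_heatmap.png, <name>_heatmap_<level>.png
    #   <name>_visualization.png  (side-by-side view, saved directly in output_dir)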
    
    # Get base filename for outputs
    base_name = os.path.splitext(os.path.basename(image1_path))[0]
    
    # Step 1: Verification Module - Process the image pair
    print(f"Processing images: {image1_path} and {image2_path}")
    results = detector.process_image_pair(image1_path, image2_path, threshold, min_area)
    
    # Save verification results
    img_processor.save_image(results['difference_image'], 
                           os.path.join(verification_dir, f"{base_name}_diff.png"))
    img_processor.save_image(results['threshold_image'], 
                           os.path.join(verification_dir, f"{base_name}_threshold.png"))
    img_processor.save_image(results['annotated_image'], 
                           os.path.join(verification_dir, f"{base_name}_annotated.png"))
    
    # Print SMI score
    print(f"SMI Score: {results['smi_score']:.4f} (1.0 = identical, 0.0 = completely different)")
    
    # Step 2: Labeling System - Label detected regions by threat level
    original_image = img_processor.load_image(image1_path)
    labeled_image, labeled_regions = labeler.label_regions(
        original_image, results['difference_image'], results['bounding_boxes'])
    
    # Save labeled image
    img_processor.save_image(labeled_image, 
                           os.path.join(labeling_dir, f"{base_name}_labeled.png"))
    
    # Get threat summary
    threat_summary = labeler.get_threat_summary(labeled_regions)
    print("\nThreat Summary:")
    print(f"Total regions detected: {threat_summary['total_regions']}")
    print(f"Threat counts: Low={threat_summary['threat_counts']['low']}, "
          f"Medium={threat_summary['threat_counts']['medium']}, "
          f"High={threat_summary['threat_counts']['high']}")
    if threat_summary['max_threat']:
        print(f"Maximum threat: {threat_summary['max_threat']['level'].upper()} "
              f"({threat_summary['max_threat']['percentage']:.1f}%)")
    print(f"Average difference: {threat_summary['average_difference']:.1f}%")
    
    # Step 3: Heatmap Visualization - Generate heatmaps for threat visualization
    # Generate standard heatmap
    heatmap = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
    img_processor.save_image(heatmap, 
                           os.path.join(heatmap_dir, f"{base_name}_heatmap.png"))
    
    # Generate multi-level heatmaps
    multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)
    
    # Save multi-level heatmaps
    for level, hmap in multi_heatmaps.items():
        if level != 'overlay':  # the 'overlay' view feeds the side-by-side visualization below
            img_processor.save_image(hmap, 
                                   os.path.join(heatmap_dir, f"{base_name}_heatmap_{level}.png"))
    
    # Save side-by-side visualization
    heatmap_gen.save_heatmap_visualization(
        original_image, multi_heatmaps['overlay'],
        os.path.join(output_dir, f"{base_name}_visualization.png")
    )
    
    print(f"\nProcessing complete. Results saved to {output_dir}")
    return threat_summary

def main():
    args = parse_args()
    
    # Ensure output directory exists
    ensure_dir(args.output_dir)
    
    # Get all image files in the input directory (sorted so pairing is deterministic)
    image_files = sorted(f for f in os.listdir(args.input_dir)
                         if f.lower().endswith(('.png', '.jpg', '.jpeg')))
    
    # Group images for comparison (assuming pairs with _original and _modified suffixes)
    # This is a simple example - you might need to adjust based on your naming convention
    original_images = [f for f in image_files if '_original' in f]
    modified_images = [f for f in image_files if '_modified' in f]
    
    # If we don't have clear pairs, just process consecutive images
    if not (original_images and modified_images):
        # Process images in pairs (1&2, 3&4, etc.)
        image_pairs = [(image_files[i], image_files[i+1]) 
                      for i in range(0, len(image_files)-1, 2)]
    else:
        # Match original and modified pairs by their shared base name
        # (strip the extension first so e.g. face_original.png pairs with face_modified.jpg)
        image_pairs = []
        for orig in original_images:
            base_name = os.path.splitext(orig)[0].replace('_original', '')
            for mod in modified_images:
                if os.path.splitext(mod)[0].replace('_modified', '') == base_name:
                    image_pairs.append((orig, mod))
                    break
    
    # Process each image pair
    results = []
    for img1, img2 in image_pairs:
        img1_path = os.path.join(args.input_dir, img1)
        img2_path = os.path.join(args.input_dir, img2)
        
        print(f"\n{'='*50}")
        print(f"Processing pair: {img1} and {img2}")
        print(f"{'='*50}")
        
        result = process_image_pair(
            img1_path, img2_path, args.output_dir,
            args.model_path, args.threshold, args.min_area
        )
        results.append({
            'pair': (img1, img2),
            'summary': result
        })
    
    # Print overall summary
    print(f"\n{'='*50}")
    print(f"Overall Summary: Processed {len(image_pairs)} image pairs")
    print(f"{'='*50}")
    
    high_threat_pairs = [r for r in results 
                        if r['summary']['threat_counts']['high'] > 0]
    print(f"Pairs with high threats: {len(high_threat_pairs)} / {len(results)}")
    
    if high_threat_pairs:
        print("\nHigh threat pairs:")
        for r in high_threat_pairs:
            print(f"- {r['pair'][0]} and {r['pair'][1]}: "
                  f"{r['summary']['threat_counts']['high']} high threat regions")

if __name__ == "__main__":
    main()