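# NOTE(review): the original first lines read "Spaces: Sleeping / Sleeping".
# This appears to be page-status residue from the hosting site rather than
# part of the module; kept here as a comment so nothing is lost.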
import os | |
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.gridspec import GridSpec | |
from deepfake_detector import DeepfakeDetector | |
from image_processor import ImageProcessor | |
from labeling import ThreatLabeler | |
from heatmap_generator import HeatmapGenerator | |
class ComparisonInterface:
    """Render the stages of a deepfake-analysis run as a PNG grid and an HTML page.

    The class ties together the detector, labeler and heatmap generator
    (see ``process_and_visualize``) and turns their outputs into two
    artefacts: a static matplotlib grid and an HTML page filled in from
    ``templates/interactive_comparison.html``.
    """

    # Threat levels, in the order the multi-level heatmaps are displayed/saved.
    _THREAT_LEVELS = ('low', 'medium', 'high')

    def __init__(self):
        """
        Initialize the comparison interface for visualizing all processing stages
        """
        self.img_processor = ImageProcessor()

    @staticmethod
    def _show_panel(fig, cell, image, title, cmap=None):
        """Draw ``image`` into grid ``cell`` of ``fig`` with a title and no axes."""
        ax = fig.add_subplot(cell)
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
        return ax

    def create_comparison_grid(self, image_pair_results, output_path, figsize=(18, 12), dpi=300):
        """
        Create a comprehensive grid visualization of all processing stages

        Layout (3 rows x 4 columns):
          * row 1: original, modified, difference, thresholded difference
          * row 2: detected regions (2 cols), threat-labeled regions (2 cols)
          * row 3: low/medium/high heatmaps when all three are present,
            otherwise the combined heatmap across 3 columns; the last
            column always holds the textual threat summary.

        Args:
            image_pair_results: Dictionary containing all processing results.
                Required keys: 'original_image', 'modified_image',
                'difference_image', 'threshold_image', 'annotated_image',
                'labeled_image', 'heatmap_overlay', 'threat_summary',
                'smi_score'. Optional key: 'multi_heatmaps'.
            output_path: Path to save the visualization
            figsize: Figure size (width, height) in inches
            dpi: Resolution for saved image

        Returns:
            ``output_path``, unchanged.

        Raises:
            KeyError: If a required key is missing from ``image_pair_results``.
        """
        # Pull everything out first so a missing key fails before a figure exists.
        original_image = image_pair_results['original_image']
        modified_image = image_pair_results['modified_image']
        difference_image = image_pair_results['difference_image']
        threshold_image = image_pair_results['threshold_image']
        annotated_image = image_pair_results['annotated_image']
        labeled_image = image_pair_results['labeled_image']
        heatmap_overlay = image_pair_results['heatmap_overlay']
        multi_heatmaps = image_pair_results.get('multi_heatmaps') or {}

        # NOTE(review): images are passed to imshow as-is. If the pipeline
        # hands over OpenCV-style BGR arrays the colours will look swapped --
        # confirm the channel order produced by ImageProcessor.
        fig = plt.figure(figsize=figsize)
        try:
            gs = GridSpec(3, 4, figure=fig)

            # Row 1: inputs and raw difference maps.
            self._show_panel(fig, gs[0, 0], original_image, 'Original Image')
            self._show_panel(fig, gs[0, 1], modified_image, 'Modified Image')
            self._show_panel(fig, gs[0, 2], difference_image, 'Difference Image', cmap='gray')
            self._show_panel(fig, gs[0, 3], threshold_image, 'Thresholded Difference', cmap='gray')

            # Row 2: region detection and threat labeling, two columns each.
            self._show_panel(fig, gs[1, 0:2], annotated_image, 'Detected Regions')
            self._show_panel(fig, gs[1, 2:4], labeled_image, 'Threat Labeled Regions')

            # Row 3: per-level heatmaps if complete, else one combined heatmap.
            if all(level in multi_heatmaps for level in self._THREAT_LEVELS):
                self._show_panel(fig, gs[2, 0], multi_heatmaps['low'], 'Low Threat Heatmap')
                self._show_panel(fig, gs[2, 1], multi_heatmaps['medium'], 'Medium Threat Heatmap')
                self._show_panel(fig, gs[2, 2], multi_heatmaps['high'], 'High Threat Heatmap')
            else:
                self._show_panel(fig, gs[2, 0:3], heatmap_overlay, 'Combined Threat Heatmap')

            # Text-only panel with the threat summary.
            summary_ax = fig.add_subplot(gs[2, 3])
            summary_ax.axis('off')
            summary_text = self._format_summary_text(
                image_pair_results['threat_summary'],
                image_pair_results['smi_score'],
            )
            summary_ax.text(0, 0.5, summary_text, fontsize=10, va='center', ha='left', wrap=True)
            summary_ax.set_title('Threat Summary')

            # Operate on this figure explicitly rather than pyplot's
            # "current figure", so other open figures cannot interfere.
            fig.suptitle("Deepfake Detection Analysis", fontsize=16)
            fig.tight_layout(rect=[0, 0, 1, 0.97])  # leave headroom for the suptitle
            fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)
        return output_path

    def _format_summary_text(self, threat_summary, smi_score):
        """
        Format threat summary as text for display

        Args:
            threat_summary: Mapping with keys 'total_regions',
                'threat_counts' (with 'low'/'medium'/'high'), 'max_threat'
                (falsy, or a mapping with 'level' and 'percentage') and
                'average_difference'.
            smi_score: Numeric similarity score, printed with 4 decimals.

        Returns:
            Multi-line string (no trailing newline).
        """
        counts = threat_summary['threat_counts']
        parts = [
            f"SMI Score: {smi_score:.4f}\n",
            "(1.0 = identical, 0.0 = different)\n\n",
            f"Total regions: {threat_summary['total_regions']}\n\n",
            "Threat counts:\n",
            f" Low: {counts['low']}\n",
            f" Medium: {counts['medium']}\n",
            f" High: {counts['high']}\n\n",
        ]
        max_threat = threat_summary['max_threat']
        if max_threat:
            # The maximum-threat section is only shown when one exists.
            parts.append(f"Maximum threat: {max_threat['level'].upper()}\n")
            parts.append(f" ({max_threat['percentage']:.1f}%)\n\n")
        parts.append(f"Average difference: {threat_summary['average_difference']:.1f}%")
        return "".join(parts)

    def create_interactive_comparison(self, image_pair_results, output_path):
        """
        Create an HTML file with interactive comparison of all processing stages

        Every stage image is written as a PNG into an ``images`` directory
        next to ``output_path``; the template's ``{{<name>_path}}``
        placeholders are replaced with paths relative to the HTML file and
        ``{{threat_summary}}`` with the formatted summary text.

        Args:
            image_pair_results: Dictionary containing all processing results
                (same keys as for ``create_comparison_grid``).
            output_path: Path to save the HTML file

        Returns:
            Path to the generated HTML file

        Raises:
            KeyError: If a required key is missing from ``image_pair_results``.
            OSError: If the template cannot be read or the output not written.
        """
        output_dir = os.path.dirname(output_path)
        images_dir = os.path.join(output_dir, 'images')
        os.makedirs(images_dir, exist_ok=True)

        # splitext keeps dots that belong to the name ("photo.v2_interactive"),
        # so differently named reports never share/overwrite image files.
        base_name = os.path.splitext(os.path.basename(output_path))[0]

        image_paths = {}

        def export(image, file_suffix, placeholder):
            """Save one stage image and record its HTML-relative path."""
            target = os.path.join(images_dir, f"{base_name}_{file_suffix}.png")
            self.img_processor.save_image(image, target)
            relative = os.path.relpath(target, output_dir or os.curdir)
            # The path ends up in HTML, where separators must be forward slashes.
            image_paths[placeholder] = relative.replace(os.sep, '/')

        # (result key, file-name suffix, template placeholder)
        stages = (
            ('original_image', 'original', 'original_image_path'),
            ('modified_image', 'modified', 'modified_image_path'),
            ('difference_image', 'difference', 'difference_image_path'),
            ('threshold_image', 'threshold', 'threshold_image_path'),
            ('annotated_image', 'annotated', 'annotated_image_path'),
            ('labeled_image', 'labeled', 'labeled_image_path'),
            ('heatmap_overlay', 'heatmap', 'heatmap_overlay_path'),
        )
        for result_key, file_suffix, placeholder in stages:
            export(image_pair_results[result_key], file_suffix, placeholder)

        # Per-level heatmaps are exported only when all three are present.
        # NOTE(review): otherwise their placeholders stay unreplaced in the
        # HTML -- confirm the template copes with that.
        multi_heatmaps = image_pair_results.get('multi_heatmaps') or {}
        if all(level in multi_heatmaps for level in self._THREAT_LEVELS):
            for level in self._THREAT_LEVELS:
                export(multi_heatmaps[level], f"heatmap_{level}", f"{level}_heatmap_path")

        threat_summary_text = self._format_summary_text(
            image_pair_results['threat_summary'],
            image_pair_results['smi_score'],
        )

        template_path = os.path.join(
            os.path.dirname(__file__), 'templates', 'interactive_comparison.html')
        with open(template_path, 'r', encoding='utf-8') as f:
            html_template = f.read()

        # Placeholders use double braces, e.g. {{original_image_path}}.
        for key, value in image_paths.items():
            html_template = html_template.replace("{{" + key + "}}", value)
        html_template = html_template.replace("{{threat_summary}}", threat_summary_text)

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(html_template)

        print(f"Interactive comparison saved to: {output_path}")
        return output_path

    def process_and_visualize(self, image1_path, image2_path, output_dir, model_path=None, threshold=30, min_area=100):
        """
        Process an image pair and create comprehensive visualization

        Runs detection, threat labeling and heatmap generation, then writes
        ``<name>_comparison.png`` and ``<name>_interactive.html`` into
        ``output_dir``, where ``<name>`` is the file name of ``image1_path``
        without its extension.

        Args:
            image1_path: Path to first image
            image2_path: Path to second image
            output_dir: Directory to save outputs
            model_path: Path to AI model (optional)
            threshold: Threshold for difference detection
            min_area: Minimum area for region detection

        Returns:
            Path to the generated interactive HTML comparison.
        """
        detector = DeepfakeDetector(model_path)
        labeler = ThreatLabeler()
        heatmap_gen = HeatmapGenerator()

        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(image1_path))[0]

        # Step 1: Verification Module - Process the image pair
        print(f"Processing images: {image1_path} and {image2_path}")
        detection_results = detector.process_image_pair(
            image1_path, image2_path, threshold, min_area)

        # Step 2: Labeling System - Label detected regions by threat level
        original_image = self.img_processor.load_image(image1_path)
        modified_image = self.img_processor.load_image(image2_path)
        labeled_image, labeled_regions = labeler.label_regions(
            original_image,
            detection_results['difference_image'],
            detection_results['bounding_boxes'],
        )
        threat_summary = labeler.get_threat_summary(labeled_regions)

        # Step 3: Heatmap Visualization - Generate heatmaps for threat visualization
        heatmap_overlay = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
        multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)

        all_results = {
            'original_image': original_image,
            'modified_image': modified_image,
            'difference_image': detection_results['difference_image'],
            'threshold_image': detection_results['threshold_image'],
            'annotated_image': detection_results['annotated_image'],
            'labeled_image': labeled_image,
            'heatmap_overlay': heatmap_overlay,
            'multi_heatmaps': multi_heatmaps,
            'threat_summary': threat_summary,
            'smi_score': detection_results['smi_score'],
            'bounding_boxes': detection_results['bounding_boxes'],
        }

        grid_output_path = os.path.join(output_dir, f"{base_name}_comparison.png")
        self.create_comparison_grid(all_results, grid_output_path)

        html_output_path = os.path.join(output_dir, f"{base_name}_interactive.html")
        self.create_interactive_comparison(all_results, html_output_path)

        print(f"Comparison visualization saved to: {grid_output_path}")
        print(f"Interactive HTML comparison saved to: {html_output_path}")
        return html_output_path  # Return the interactive HTML path as the primary output
def _match_image_pairs(image_files):
    """Group image file names into (first, second) comparison pairs.

    Returns ``(pairs, used_naming_pattern)``. When at least one name
    contains ``_original`` and one contains ``_modified``, files are paired
    by their name with the extension and that marker removed
    (``cat_original.png`` <-> ``cat_modified.jpg``). Otherwise files are
    paired consecutively (1&2, 3&4, ...); a trailing unpaired file is ignored.
    """
    originals = [name for name in image_files if '_original' in name]
    modifieds = [name for name in image_files if '_modified' in name]

    if not (originals and modifieds):
        return list(zip(image_files[0::2], image_files[1::2])), False

    # Index modified images by their marker-free stem. The extension is
    # dropped first so it cannot prevent a match and so .png/.jpg pairs work.
    modified_by_stem = {}
    for name in modifieds:
        stem = os.path.splitext(name)[0].replace('_modified', '')
        modified_by_stem.setdefault(stem, name)  # first match wins

    pairs = []
    for name in originals:
        stem = os.path.splitext(name)[0].replace('_original', '')
        partner = modified_by_stem.get(stem)
        if partner is not None:
            pairs.append((name, partner))
    return pairs, True


def batch_process_directory(input_dir, output_dir, model_path=None, threshold=30, min_area=100):
    """
    Process all image pairs in a directory and create comparison visualizations

    Files ending in .png/.jpg/.jpeg (case-insensitive) are considered, in
    sorted name order so that pairing is reproducible across runs and
    platforms.

    Args:
        input_dir: Directory containing input images
        output_dir: Directory to save outputs (created if missing)
        model_path: Path to AI model (optional)
        threshold: Threshold for difference detection
        min_area: Minimum area for region detection

    Returns:
        List of paths to generated HTML comparison files; empty when no
        images or no valid pairs are found.
    """
    os.makedirs(output_dir, exist_ok=True)

    # os.listdir order is arbitrary; sort so consecutive pairing is stable.
    image_files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_files:
        print(f"No image files found in {input_dir}")
        return []

    image_pairs, used_naming_pattern = _match_image_pairs(image_files)
    if used_naming_pattern:
        print(f"Found {len(image_pairs)} original/modified image pairs.")
    else:
        if len(image_files) < 2:
            print("Need at least 2 images to compare")
            return []
        print(f"No _original/_modified naming pattern found. Processing {len(image_pairs)} consecutive pairs.")

    if not image_pairs:
        print("No valid image pairs found to process")
        return []

    interface = ComparisonInterface()
    separator = '=' * 50

    html_paths = []
    for first_name, second_name in image_pairs:
        print(f"\n{separator}")
        print(f"Processing pair: {first_name} and {second_name}")
        print(f"{separator}")
        html_paths.append(
            interface.process_and_visualize(
                os.path.join(input_dir, first_name),
                os.path.join(input_dir, second_name),
                output_dir,
                model_path, threshold, min_area,
            )
        )

    print(f"\n{separator}")
    print(f"Overall Summary: Processed {len(image_pairs)} image pairs")
    print(f"{separator}")
    print(f"All comparison visualizations saved to: {output_dir}")
    return html_paths
# Command-line entry point: process every image pair found in --input_dir and
# optionally open the first generated HTML report in the default browser.
if __name__ == "__main__":
    import argparse
    import webbrowser
    from pathlib import Path

    parser = argparse.ArgumentParser(description='Deepfake Detection Comparison Interface')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory containing input images')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output visualizations')
    parser.add_argument('--model_path', type=str, help='Path to Nvidia AI model (optional)')
    parser.add_argument('--threshold', type=int, default=30, help='Threshold for difference detection (0-255)')
    parser.add_argument('--min_area', type=int, default=100, help='Minimum area for region detection')
    parser.add_argument('--open_browser', action='store_true', help='Automatically open HTML results in browser')
    args = parser.parse_args()

    # Process all images in the directory
    html_paths = batch_process_directory(
        args.input_dir, args.output_dir,
        args.model_path, args.threshold, args.min_area
    )

    # Build the file URL with pathlib rather than 'file://' + path: as_uri()
    # yields a valid URL on Windows (drive letters, backslashes) and
    # percent-encodes spaces and other special characters.
    first_result_uri = None
    if html_paths:
        first_result_uri = Path(os.path.abspath(html_paths[0])).as_uri()

    # Open the first result in browser if requested
    if args.open_browser and first_result_uri is not None:
        print(f"\nOpening first result in web browser: {html_paths[0]}")
        webbrowser.open(first_result_uri)

    print("\nTo view interactive results, open the HTML files in your web browser.")
    if first_result_uri is not None:
        print("Example: " + first_result_uri)
    print("\nDone!")