# heatmap / streamlit_app.py
# Committed by: noumanjavaid
# Last commit: 3958771 "Update streamlit_app.py" (verified)
import hashlib
import os
import tempfile

import numpy as np
import streamlit as st
from PIL import Image

from comparison_interface import ComparisonInterface
from deepfake_detector import DeepfakeDetector
from image_processor import ImageProcessor
from labeling import ThreatLabeler
from heatmap_generator import HeatmapGenerator
# Browser-tab title plus a full-width layout for the whole app.
st.set_page_config(layout="wide", page_title="Deepfake Detection Analysis")

# Helpers reused further down the script: `comparison` renders the
# comparison grid and `img_processor` loads images from disk.
comparison = ComparisonInterface()
img_processor = ImageProcessor()
# Stylesheet for the custom classes used by the HTML snippets below
# (.main-header, .sub-header, .info-text) plus extra spacing around images.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 2.5rem;
font-weight: bold;
color: #1E3A8A;
text-align: center;
margin-bottom: 1rem;
}
.sub-header {
font-size: 1.5rem;
font-weight: bold;
color: #2563EB;
margin-top: 1rem;
margin-bottom: 0.5rem;
}
.info-text {
font-size: 1rem;
color: #4B5563;
}
.stImage {
margin-top: 1rem;
margin-bottom: 1rem;
}
</style>
"""

# Raw HTML is required for a <style> tag to take effect.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Title banner and a one-line explanation of what the user should do.
st.markdown('<p class="main-header">Deepfake Detection Analysis</p>', unsafe_allow_html=True)
st.markdown('<p class="info-text">Upload original and modified images to analyze potential deepfakes</p>', unsafe_allow_html=True)

# Two side-by-side upload slots: the reference image on the left and the
# suspect image on the right. Each yields None until a file is chosen.
upload_left, upload_right = st.columns(2)

with upload_left:
    st.markdown('<p class="sub-header">Original Image</p>', unsafe_allow_html=True)
    original_file = st.file_uploader(
        "Upload the original image",
        type=["jpg", "jpeg", "png"],
    )

with upload_right:
    st.markdown('<p class="sub-header">Modified Image</p>', unsafe_allow_html=True)
    modified_file = st.file_uploader(
        "Upload the potentially modified image",
        type=["jpg", "jpeg", "png"],
    )
# Tuning knobs forwarded to DeepfakeDetector.process_image_pair below.
st.markdown('<p class="sub-header">Analysis Parameters</p>', unsafe_allow_html=True)
param_left, param_right = st.columns(2)

with param_left:
    threshold = st.slider(
        "Difference Threshold",
        min_value=10,
        max_value=100,
        value=30,
        help="Higher values detect only significant differences",
    )

with param_right:
    min_area = st.slider(
        "Minimum Detection Area",
        min_value=50,
        max_value=500,
        value=100,
        help="Minimum area size to consider as a modified region",
    )
# Main page body. When both images are uploaded, run the
# detect -> label -> heatmap pipeline, save the visualizations, and render
# the comparison grid plus a threat summary. Otherwise show upload
# instructions and, if present, a pre-rendered sample result.

# Directory containing this script; both the saved outputs and the sample
# image live under its comparison_output/ subfolder.
app_dir = os.path.dirname(os.path.abspath(__file__))

if original_file and modified_file:
    # The project pipeline takes file paths, so the uploads are copied into a
    # scratch directory that is deleted when this `with` block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        original_path = os.path.join(temp_dir, "original.jpg")
        modified_path = os.path.join(temp_dir, "modified.jpg")

        # Write each upload to disk and fold its bytes into a digest that
        # identifies this image pair in the output filenames further down.
        # Each payload is length-prefixed so that different pairs whose
        # concatenated bytes happen to coincide still hash differently.
        upload_digest = hashlib.sha256()
        for upload, dest_path in ((original_file, original_path), (modified_file, modified_path)):
            data = upload.getvalue()
            with open(dest_path, "wb") as f:
                f.write(data)
            upload_digest.update(len(data).to_bytes(8, "big"))
            upload_digest.update(data)

        with st.spinner("Processing images and generating analysis..."):
            # Pipeline components are created fresh on every run.
            detector = DeepfakeDetector()
            labeler = ThreatLabeler()
            heatmap_gen = HeatmapGenerator()

            # Step 1: Verification Module - Process the image pair
            detection_results = detector.process_image_pair(original_path, modified_path, threshold, min_area)
            difference_image = detection_results['difference_image']
            bounding_boxes = detection_results['bounding_boxes']

            # Step 2: Labeling System - Label detected regions by threat level
            original_image = img_processor.load_image(original_path)
            modified_image = img_processor.load_image(modified_path)
            labeled_image, labeled_regions = labeler.label_regions(
                original_image, difference_image, bounding_boxes)
            threat_summary = labeler.get_threat_summary(labeled_regions)

            # Step 3: Heatmap Visualization - Generate heatmaps for threat visualization
            heatmap_overlay = heatmap_gen.generate_threat_heatmap(original_image, labeled_regions)
            multi_heatmaps = heatmap_gen.generate_multi_level_heatmap(original_image, labeled_regions)

            # Inputs for the comparison grid.
            all_results = {
                'original_image': original_image,
                'modified_image': modified_image,
                'difference_image': difference_image,
                'threshold_image': detection_results['threshold_image'],
                'annotated_image': detection_results['annotated_image'],
                'labeled_image': labeled_image,
                'heatmap_overlay': heatmap_overlay,
                'multi_heatmaps': multi_heatmaps,
                'threat_summary': threat_summary,
                'smi_score': detection_results['smi_score'],
                'bounding_boxes': bounding_boxes,
            }

            # Outputs persist in comparison_output/. The filename stem pairs the
            # upload's name with the digest of the uploaded bytes. Re-running the
            # same pair overwrites its own files. A different pair never
            # collides with it, even when it shares a filename with another
            # session's upload or with the deepfake1_comparison.png sample
            # shown below when nothing is uploaded.
            output_dir = os.path.join(app_dir, "comparison_output")
            os.makedirs(output_dir, exist_ok=True)
            # UploadedFile.name is supplied by the client: keep only the final
            # path component so it cannot point outside output_dir.
            base_name = os.path.splitext(os.path.basename(original_file.name))[0] or "upload"
            run_stem = f"{base_name}_{upload_digest.hexdigest()[:12]}"
            output_path = os.path.join(output_dir, f"{run_stem}_combined_overlay.png")
            grid_output_path = os.path.join(output_dir, f"{run_stem}_comparison.png")

            # Combined overlay is saved to disk only; this page does not display it.
            from combined_visualization import CombinedVisualizer
            visualizer = CombinedVisualizer()
            combined_results = {
                'original_image': original_image,
                'difference_image': difference_image,
                'bounding_boxes': bounding_boxes,
                'multi_heatmaps': multi_heatmaps,
                'labeled_regions': labeled_regions,
            }
            visualizer.create_combined_visualization(
                combined_results, output_path,
                alpha_diff=0.4, alpha_low=0.3, alpha_medium=0.4,
            )

            # Comparison grid: saved, then shown on the page.
            comparison.create_comparison_grid(all_results, grid_output_path)

    # The grid file lives in output_dir, not temp_dir, so it is still readable here.
    st.markdown('<p class="sub-header">Comprehensive Analysis</p>', unsafe_allow_html=True)
    st.image(grid_output_path, use_container_width=True)

    # Threat summary
    st.markdown('<p class="sub-header">Threat Summary</p>', unsafe_allow_html=True)
    st.markdown(f"**SMI Score:** {detection_results['smi_score']:.4f} (1.0 = identical, 0.0 = completely different)")
    st.markdown(f"**Total regions detected:** {threat_summary['total_regions']}")

    # One column per threat level, in ascending severity.
    threat_counts = threat_summary['threat_counts']
    for column, level in zip(st.columns(3), ("low", "medium", "high")):
        with column:
            st.markdown(f"**{level.capitalize()} threats:** {threat_counts[level]}")

    max_threat = threat_summary['max_threat']
    if max_threat:
        st.markdown(f"**Maximum threat:** {max_threat['level'].upper()} ({max_threat['percentage']:.1f}%)")
        st.markdown(f"**Average difference:** {threat_summary['average_difference']:.1f}%")
else:
    # Display instructions when images are not yet uploaded
    st.info("Please upload both original and modified images to begin analysis.")

    # Show a previously generated result, if one exists, as an example.
    st.markdown('<p class="sub-header">Sample Analysis Output</p>', unsafe_allow_html=True)
    sample_path = os.path.join(app_dir, "comparison_output", "deepfake1_comparison.png")
    if os.path.exists(sample_path):
        st.image(sample_path, use_container_width=True)
        st.caption("Sample analysis showing all detection stages in a single comprehensive view")
    else:
        st.write("Sample image not available. Please upload images to see the analysis.")