noumanjavaid commited on
Commit
1390aae
·
verified ·
1 Parent(s): 433addc

Upload heatmap_generator.py

Browse files
Files changed (1) hide show
  1. heatmap_generator.py +212 -0
heatmap_generator.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from image_processor import ImageProcessor
6
+
7
+ class HeatmapGenerator:
8
+ def __init__(self):
9
+ """
10
+ Initialize the heatmap generator for visualizing threat areas
11
+ """
12
+ self.image_processor = ImageProcessor()
13
+
14
+ # Define colormap options
15
+ self.colormap_options = {
16
+ 'hot': cv2.COLORMAP_HOT, # Red-yellow-white, good for high intensity
17
+ 'jet': cv2.COLORMAP_JET, # Blue-cyan-yellow-red, good for range
18
+ 'inferno': cv2.COLORMAP_INFERNO, # Purple-red-yellow, good for threat
19
+ 'plasma': cv2.COLORMAP_PLASMA # Purple-red-yellow, alternative
20
+ }
21
+
22
+ # Default colormap
23
+ self.default_colormap = 'inferno'
24
+
25
+ def generate_heatmap_from_diff(self, diff_image, threshold=0, blur_size=15):
26
+ """
27
+ Generate a heatmap directly from a difference image
28
+
29
+ Args:
30
+ diff_image: Difference image (0-255 range)
31
+ threshold: Minimum difference value to consider (0-255)
32
+ blur_size: Size of Gaussian blur kernel for smoothing
33
+
34
+ Returns:
35
+ Heatmap image
36
+ """
37
+ # Apply threshold to filter out low differences
38
+ _, thresholded = cv2.threshold(diff_image, threshold, 255, cv2.THRESH_TOZERO)
39
+
40
+ # Apply Gaussian blur to smooth the heatmap
41
+ if blur_size > 0:
42
+ blurred = cv2.GaussianBlur(thresholded, (blur_size, blur_size), 0)
43
+ else:
44
+ blurred = thresholded
45
+
46
+ # Apply colormap
47
+ heatmap = cv2.applyColorMap(blurred, self.colormap_options[self.default_colormap])
48
+
49
+ # Convert to RGB for consistent display
50
+ heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
51
+
52
+ return heatmap_rgb
53
+
54
+ def generate_heatmap_from_regions(self, image_shape, labeled_regions, sigma=40):
55
+ """
56
+ Generate a heatmap from labeled regions based on threat levels
57
+
58
+ Args:
59
+ image_shape: Shape of the original image (height, width)
60
+ labeled_regions: List of regions with threat levels from ThreatLabeler
61
+ sigma: Standard deviation for Gaussian kernel
62
+
63
+ Returns:
64
+ Heatmap image
65
+ """
66
+ # Create an empty heatmap
67
+ height, width = image_shape[:2]
68
+ heatmap = np.zeros((height, width), dtype=np.float32)
69
+
70
+ # Define threat level weights with increased intensity
71
+ threat_weights = {
72
+ 'low': 0.4,
73
+ 'medium': 0.7,
74
+ 'high': 1.0
75
+ }
76
+
77
+ # Add each region to the heatmap with appropriate weight
78
+ for region in labeled_regions:
79
+ bbox = region['bbox']
80
+ threat_level = region['threat_level']
81
+ diff_percentage = region['difference_percentage']
82
+
83
+ # Calculate center of bounding box
84
+ x, y, w, h = bbox
85
+ center_x, center_y = x + w // 2, y + h // 2
86
+
87
+ # Calculate intensity based on threat level and difference percentage with increased brightness
88
+ intensity = threat_weights[threat_level] * (diff_percentage / 100) * 1.2
89
+
90
+ # Create a Gaussian kernel centered at the region with increased sigma for more circular spread
91
+ y_coords, x_coords = np.ogrid[:height, :width]
92
+ dist_from_center = ((y_coords - center_y) ** 2 + (x_coords - center_x) ** 2) / (2 * sigma ** 2)
93
+ kernel = np.exp(-dist_from_center) * intensity
94
+
95
+ # Add to heatmap
96
+ heatmap += kernel
97
+
98
+ # Normalize heatmap to 0-255 range
99
+ if np.max(heatmap) > 0: # Avoid division by zero
100
+ heatmap = (heatmap / np.max(heatmap) * 255).astype(np.uint8)
101
+ else:
102
+ heatmap = np.zeros((height, width), dtype=np.uint8)
103
+
104
+ # Apply colormap
105
+ colored_heatmap = cv2.applyColorMap(heatmap, self.colormap_options[self.default_colormap])
106
+ colored_heatmap = cv2.cvtColor(colored_heatmap, cv2.COLOR_BGR2RGB)
107
+
108
+ return colored_heatmap
109
+
110
+ def overlay_heatmap(self, original_image, heatmap, alpha=0.6):
111
+ """
112
+ Overlay heatmap on original image
113
+
114
+ Args:
115
+ original_image: Original image
116
+ heatmap: Heatmap image
117
+ alpha: Transparency factor (0-1)
118
+
119
+ Returns:
120
+ Overlaid image
121
+ """
122
+ # Ensure images are the same size
123
+ if original_image.shape[:2] != heatmap.shape[:2]:
124
+ heatmap = cv2.resize(heatmap, (original_image.shape[1], original_image.shape[0]))
125
+
126
+ # Overlay heatmap on original image
127
+ return self.image_processor.overlay_images(original_image, heatmap, alpha)
128
+
129
+ def generate_threat_heatmap(self, image, labeled_regions, overlay=True, alpha=0.6):
130
+ """
131
+ Generate a complete threat heatmap visualization
132
+
133
+ Args:
134
+ image: Original image
135
+ labeled_regions: List of regions with threat levels
136
+ overlay: Whether to overlay on original image
137
+ alpha: Transparency for overlay
138
+
139
+ Returns:
140
+ Heatmap image or overlaid image
141
+ """
142
+ # Generate heatmap from labeled regions
143
+ heatmap = self.generate_heatmap_from_regions(image.shape, labeled_regions)
144
+
145
+ # Overlay on original image if requested
146
+ if overlay:
147
+ return self.overlay_heatmap(image, heatmap, alpha)
148
+ else:
149
+ return heatmap
150
+
151
+ def save_heatmap_visualization(self, image, heatmap, output_path, dpi=300):
152
+ """
153
+ Save a side-by-side visualization of original image and heatmap
154
+
155
+ Args:
156
+ image: Original image
157
+ heatmap: Heatmap image
158
+ output_path: Path to save visualization
159
+ dpi: Resolution for saved image
160
+ """
161
+ # Create figure with two subplots
162
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
163
+
164
+ # Display original image
165
+ ax1.imshow(image)
166
+ ax1.set_title('Original Image')
167
+ ax1.axis('off')
168
+
169
+ # Display heatmap
170
+ ax2.imshow(heatmap)
171
+ ax2.set_title('Threat Heatmap')
172
+ ax2.axis('off')
173
+
174
+ # Save figure
175
+ plt.tight_layout()
176
+ plt.savefig(output_path, dpi=dpi, bbox_inches='tight')
177
+ plt.close()
178
+
179
+ def generate_multi_level_heatmap(self, image, labeled_regions):
180
+ """
181
+ Generate separate heatmaps for each threat level
182
+
183
+ Args:
184
+ image: Original image
185
+ labeled_regions: List of regions with threat levels
186
+
187
+ Returns:
188
+ Dictionary with heatmaps for each threat level and combined
189
+ """
190
+ # Create separate lists for each threat level
191
+ low_regions = [r for r in labeled_regions if r['threat_level'] == 'low']
192
+ medium_regions = [r for r in labeled_regions if r['threat_level'] == 'medium']
193
+ high_regions = [r for r in labeled_regions if r['threat_level'] == 'high']
194
+
195
+ # Generate heatmaps for each level
196
+ low_heatmap = self.generate_heatmap_from_regions(image.shape, low_regions)
197
+ medium_heatmap = self.generate_heatmap_from_regions(image.shape, medium_regions)
198
+ high_heatmap = self.generate_heatmap_from_regions(image.shape, high_regions)
199
+
200
+ # Generate combined heatmap
201
+ combined_heatmap = self.generate_heatmap_from_regions(image.shape, labeled_regions)
202
+
203
+ # Overlay all on original image
204
+ combined_overlay = self.overlay_heatmap(image, combined_heatmap)
205
+
206
+ return {
207
+ 'low': low_heatmap,
208
+ 'medium': medium_heatmap,
209
+ 'high': high_heatmap,
210
+ 'combined': combined_heatmap,
211
+ 'overlay': combined_overlay
212
+ }