Sompote committed
Commit f225817 · verified · 1 Parent(s): ce57fb7

Upload 2 files

Files changed (2):
  1. app.py +1469 -0
  2. requirements.txt +28 -0
app.py ADDED
@@ -0,0 +1,1469 @@
import gradio as gr
import numpy as np
import cv2
import requests
import json
import base64
from PIL import Image
import io
import os
from dotenv import load_dotenv
from collections import defaultdict
import time
from skimage.metrics import structural_similarity as ssim

# Load environment variables
load_dotenv()

# Define API endpoint from environment variable
API_URL = os.getenv("API_URL", "http://122.155.170.240:81")
print(f"Using API URL: {API_URL}")
DEFAULT_CONFIDENCE = float(os.getenv("DEFAULT_CONFIDENCE_THRESHOLD", "0.25"))

def calculate_iou(box1, box2):
    """Calculate Intersection over Union (IoU) between two bounding boxes"""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection

    return intersection / union if union > 0 else 0

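# Worked example (illustrative values, not from the original source): for
# box1 = [0, 0, 10, 10] and box2 = [5, 5, 15, 15], the overlap is the 5x5
# square from (5, 5) to (10, 10), so
#   intersection = 5 * 5 = 25
#   union = 100 + 100 - 25 = 175
#   IoU = 25 / 175 ≈ 0.143
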
def calculate_bbox_similarity(bbox1, bbox2):
    """Calculate similarity between two bounding boxes using IoU and center distance"""
    try:
        # Calculate IoU
        iou = calculate_iou(bbox1, bbox2)

        # Calculate center distance
        center1 = get_box_center(bbox1)
        center2 = get_box_center(bbox2)

        if center1 is None or center2 is None:
            return 0.0

        distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)

        # Normalize distance based on bbox size
        bbox_size = max(bbox1[2] - bbox1[0], bbox1[3] - bbox1[1])
        normalized_distance = distance / max(bbox_size, 1)

        # Combine IoU and distance for final similarity score
        similarity = iou * 0.7 + max(0, 1 - normalized_distance * 0.3) * 0.3

        return similarity
    except Exception:
        return 0.0

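# Illustrative combination (assumed inputs): with iou = 0.2 and a center
# offset of half the box size (normalized_distance = 0.5), the blended score
# is 0.2 * 0.7 + max(0, 1 - 0.5 * 0.3) * 0.3 = 0.14 + 0.255 = 0.395.
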
def get_box_center(bbox):
    """Calculate center point of bounding box"""
    try:
        # Handle different bbox formats (x,y,w,h) or (x1,y1,x2,y2)
        if len(bbox) == 4:
            if bbox[2] < bbox[0] or bbox[3] < bbox[1]:  # If it's x1,y1,x2,y2 format
                x = (bbox[0] + bbox[2]) / 2
                y = (bbox[1] + bbox[3]) / 2
            else:  # If it's x,y,w,h format
                x = bbox[0] + bbox[2] / 2
                y = bbox[1] + bbox[3] / 2
        else:
            return None
        return (x, y)
    except Exception:
        return None

def calculate_movement(prev_center, curr_center, min_movement=10):
    """Calculate if there's significant movement between frames"""
    try:
        if prev_center is None or curr_center is None:
            return False
        dx = curr_center[0] - prev_center[0]
        dy = curr_center[1] - prev_center[1]
        distance = np.sqrt(dx * dx + dy * dy)
        return distance > min_movement
    except Exception:
        return False

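# Illustrative check (assumed values): a shift of dx = 6, dy = 8 gives a
# distance of sqrt(36 + 64) = 10, which is not strictly greater than the
# default min_movement of 10, so it does not count as movement.
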
def extract_bbox_image(frame, bbox):
    """Extract image region from bounding box"""
    try:
        if frame is None or len(bbox) != 4:
            return None

        # Convert bbox to integers and ensure valid coordinates
        x1, y1, x2, y2 = map(int, bbox)

        # Handle different bbox formats
        if x2 < x1 or y2 < y1:  # If it's x,y,w,h format
            x1, y1, w, h = bbox
            x2, y2 = x1 + w, y1 + h

        # Ensure coordinates are within frame bounds
        h, w = frame.shape[:2]
        x1 = max(0, min(x1, w - 1))
        y1 = max(0, min(y1, h - 1))
        x2 = max(x1 + 1, min(x2, w))
        y2 = max(y1 + 1, min(y2, h))

        # Extract region
        bbox_img = frame[y1:y2, x1:x2]

        # Resize to standard size for comparison (64x64)
        if bbox_img.size > 0:
            bbox_img = cv2.resize(bbox_img, (64, 64))
            return bbox_img
        return None
    except Exception:
        return None

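# Note: crops are normalized to 64x64 so that the histogram, SSIM, and ORB
# comparisons below always operate on inputs of identical size, regardless of
# how large the detected bounding box was in the original frame.
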
def calculate_histogram_similarity(img1, img2):
    """Calculate histogram-based similarity between two images"""
    try:
        if img1 is None or img2 is None:
            return 0.0

        # Convert to HSV for better color comparison
        hsv1 = cv2.cvtColor(img1, cv2.COLOR_BGR2HSV)
        hsv2 = cv2.cvtColor(img2, cv2.COLOR_BGR2HSV)

        # Calculate histograms
        hist1 = cv2.calcHist([hsv1], [0, 1, 2], None, [50, 60, 60], [0, 180, 0, 256, 0, 256])
        hist2 = cv2.calcHist([hsv2], [0, 1, 2], None, [50, 60, 60], [0, 180, 0, 256, 0, 256])

        # Compare histograms using correlation
        correlation = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)

        # Normalize to 0-1 range
        return max(0, correlation)
    except Exception:
        return 0.0

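# Note: cv2.HISTCMP_CORREL returns values in [-1, 1] (1 = identical
# distributions, negative = anti-correlated), so max(0, correlation) clamps
# the result into the [0, 1] range used by the other similarity terms.
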
def calculate_ssim_similarity(img1, img2):
    """Calculate Structural Similarity Index (SSIM) between two images"""
    try:
        if img1 is None or img2 is None:
            return 0.0

        # Convert to grayscale
        gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

        # Calculate SSIM
        similarity_index = ssim(gray1, gray2)

        # Normalize to 0-1 range (SSIM can be negative)
        return max(0, (similarity_index + 1) / 2)
    except Exception:
        return 0.0

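# Mapping examples for (similarity_index + 1) / 2: a raw SSIM of -1 maps to
# 0.0, 0 maps to 0.5, and 1 maps to 1.0, so the returned score is always in
# [0, 1] like the other similarity components.
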
def calculate_feature_similarity(img1, img2):
    """Calculate feature-based similarity using ORB features"""
    try:
        if img1 is None or img2 is None:
            return 0.0

        # Convert to grayscale
        gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

        # Initialize ORB detector
        orb = cv2.ORB_create(nfeatures=50)

        # Find keypoints and descriptors
        kp1, des1 = orb.detectAndCompute(gray1, None)
        kp2, des2 = orb.detectAndCompute(gray2, None)

        if des1 is None or des2 is None or len(des1) < 5 or len(des2) < 5:
            return 0.0

        # Match features
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = bf.match(des1, des2)

        # Calculate similarity based on good matches
        if len(matches) > 0:
            # Sort matches by distance
            matches = sorted(matches, key=lambda x: x.distance)
            good_matches = [m for m in matches if m.distance < 50]  # Threshold for good matches

            # Similarity based on ratio of good matches
            similarity = len(good_matches) / max(len(kp1), len(kp2))
            return min(1.0, similarity)

        return 0.0
    except Exception:
        return 0.0

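# Note: ORB descriptors are 256-bit binary strings, so the match distance here
# is a Hamming distance (0-256 differing bits); the hard-coded cutoff of 50 is
# this app's notion of a "good" match, and the score is the fraction of
# detected keypoints that found such a match.
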
def calculate_enhanced_bbox_similarity(bbox1, bbox2, frame1=None, frame2=None):
    """Enhanced similarity calculation combining geometric and visual features"""
    try:
        # Geometric similarity (IoU + distance) - 40% weight
        geometric_similarity = calculate_bbox_similarity(bbox1, bbox2)

        # If no frames provided, use only geometric similarity
        if frame1 is None or frame2 is None:
            return geometric_similarity

        # Extract image regions from bounding boxes
        img1 = extract_bbox_image(frame1, bbox1)
        img2 = extract_bbox_image(frame2, bbox2)

        if img1 is None or img2 is None:
            return geometric_similarity

        # Visual similarity components
        hist_similarity = calculate_histogram_similarity(img1, img2)  # Color similarity
        ssim_similarity = calculate_ssim_similarity(img1, img2)  # Structural similarity
        feature_similarity = calculate_feature_similarity(img1, img2)  # Feature similarity

        # Combine all similarities with weights
        final_similarity = (
            geometric_similarity * 0.4 +   # Geometric (IoU + distance)
            hist_similarity * 0.25 +       # Color histogram
            ssim_similarity * 0.25 +       # Structural similarity
            feature_similarity * 0.1       # Feature matching
        )

        return min(1.0, final_similarity)

    except Exception:
        return calculate_bbox_similarity(bbox1, bbox2)  # Fallback to geometric only

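# Worked example of the weighted blend (illustrative scores): with
# geometric = 0.5, histogram = 0.8, SSIM = 0.6, and feature = 0.4, the
# combined similarity is
#   0.5 * 0.4 + 0.8 * 0.25 + 0.6 * 0.25 + 0.4 * 0.1 = 0.59
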
class TrackedObject:
    def __init__(self, obj_id, obj_class, bbox):
        self.id = obj_id
        self.class_name = obj_class
        self.alternative_classes = set()  # Track alternative classes (e.g., person when primary is motorcycle)
        self.trajectory = []  # List of center points
        self.bboxes = []  # List of bounding boxes
        self.frame_images = []  # Store recent frame images for visual comparison
        self.counted = False
        self.last_seen = 0  # Frame number when last seen
        self.first_seen = 0  # Frame number when first seen
        self.frames_in_red_zone = 0  # Number of consecutive frames in red zone
        self.warning_triggered = False  # Whether warning has been triggered
        self.red_zone_entry_frame = None  # Frame when object entered red zone
        self.similarity_scores = []  # Track similarity scores over time
        self.add_detection(bbox)

    def update_class(self, new_class):
        """Update object class, handling motorcycle+person combinations"""
        # Prioritize motorcycle over person (motorcycle with rider)
        if self.class_name == 'person' and new_class == 'motorcycle':
            self.alternative_classes.add(self.class_name)
            self.class_name = new_class
        elif self.class_name == 'motorcycle' and new_class == 'person':
            self.alternative_classes.add(new_class)
            # Keep motorcycle as primary class
        elif new_class != self.class_name:
            # Different class detected, add to alternatives
            self.alternative_classes.add(new_class)

    def get_primary_class(self):
        """Get the primary class for counting purposes"""
        # Always prioritize motorcycle if it's been detected
        if self.class_name == 'motorcycle' or 'motorcycle' in self.alternative_classes:
            return 'motorcycle'
        return self.class_name

    def add_detection(self, bbox, frame_image=None):
        try:
            center = get_box_center(bbox)
            if center is not None:
                self.trajectory.append(center)
                self.bboxes.append(bbox)

                # Store frame image for visual comparison
                if frame_image is not None:
                    self.frame_images.append(frame_image.copy())

                # Keep only recent history to prevent memory issues
                if len(self.trajectory) > 50:
                    self.trajectory = self.trajectory[-25:]
                    self.bboxes = self.bboxes[-25:]
                    self.frame_images = self.frame_images[-25:] if self.frame_images else []
        except Exception:
            pass

    def has_movement(self, min_movement=10):
        try:
            if len(self.trajectory) < 2:
                return False
            return calculate_movement(self.trajectory[-2], self.trajectory[-1], min_movement)
        except Exception:
            return False

    def update_red_zone_status(self, is_in_red_zone, frame_number):
        """Update red zone status and handle warnings"""
        if is_in_red_zone:
            if self.red_zone_entry_frame is None:
                self.red_zone_entry_frame = frame_number
                # Mark as entered red zone immediately when first detected
                return "entered"
            self.frames_in_red_zone += 1

            # Check if warning should be triggered using configurable threshold
            if self.frames_in_red_zone > state.warning_frame_threshold and not self.warning_triggered:
                self.warning_triggered = True
                return "warning"  # Indicate that a warning should be shown
        else:
            # Object left red zone, reset counters
            if self.red_zone_entry_frame is not None:
                # Object was in red zone and now left
                self.frames_in_red_zone = 0
                self.red_zone_entry_frame = None
                self.warning_triggered = False
                return "exited"

        return None

    def get_similarity_with(self, other_bbox, current_frame=None, similarity_threshold=0.5):
        """Calculate enhanced similarity with another bounding box using visual comparison"""
        if len(self.bboxes) == 0:
            return 0.0

        current_bbox = self.bboxes[-1]

        # Get the most recent frame image for comparison
        previous_frame = self.frame_images[-1] if self.frame_images else None

        # Use enhanced similarity calculation with visual comparison
        similarity = calculate_enhanced_bbox_similarity(
            current_bbox,
            other_bbox,
            previous_frame,
            current_frame
        )

        # Store similarity score for debugging
        self.similarity_scores.append({
            'frame': state.frame_count,
            'similarity': similarity,
            'bbox': other_bbox,
            'method': 'enhanced' if (previous_frame is not None and current_frame is not None) else 'geometric'
        })

        # Keep only recent similarity scores to prevent memory issues
        if len(self.similarity_scores) > 20:
            self.similarity_scores = self.similarity_scores[-10:]

        return similarity

def is_similar_object(obj1, obj2, similarity_threshold=0.35):
    """Check if two objects are similar based on class, position and bounding box similarity"""
    try:
        # Allow cross-class matching for motorcycle and person (same object - person on motorcycle)
        class1, class2 = obj1['class'], obj2['class']

        # Check if classes are compatible (same class or motorcycle+person combination)
        compatible_classes = (
            class1 == class2 or  # Same class
            (class1 == 'motorcycle' and class2 == 'person') or  # Person on motorcycle
            (class1 == 'person' and class2 == 'motorcycle')  # Motorcycle with person
        )

        if not compatible_classes:
            return False

        box1 = obj1['bbox']
        box2 = obj2['bbox']

        # Convert to x1,y1,x2,y2 format if needed
        if len(box1) == 4 and len(box2) == 4:
            if box1[2] < box1[0] or box1[3] < box1[1]:  # Already in x1,y1,x2,y2
                bbox1 = box1
            else:  # Convert from x,y,w,h to x1,y1,x2,y2
                bbox1 = [box1[0], box1[1], box1[0] + box1[2], box1[1] + box1[3]]

            if box2[2] < box2[0] or box2[3] < box2[1]:  # Already in x1,y1,x2,y2
                bbox2 = box2
            else:  # Convert from x,y,w,h to x1,y1,x2,y2
                bbox2 = [box2[0], box2[1], box2[0] + box2[2], box2[1] + box2[3]]

            similarity = calculate_bbox_similarity(bbox1, bbox2)

            # Use lower threshold for motorcycle+person combinations
            if class1 != class2 and ('motorcycle' in [class1, class2] and 'person' in [class1, class2]):
                return similarity > (similarity_threshold * 0.7)  # 30% more lenient for cross-class

            return similarity > similarity_threshold
        return False
    except Exception:
        return False

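# Cross-class example: with the default similarity_threshold of 0.35, a
# motorcycle detection and a person detection are treated as the same object
# once their bbox similarity exceeds 0.35 * 0.7 = 0.245, while two same-class
# detections must still exceed the full 0.35.
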
# Global state for protection area and previous detections
class State:
    def __init__(self):
        self.protection_points = []  # Store clicked points
        self.detected_segments = []
        self.segment_image = None
        self.selected_segments = []
        self.previous_detections = None
        self.cached_protection_area = None
        self.current_image = None  # Store current image for drawing
        self.original_dims = None  # Store original image dimensions
        self.display_dims = None  # Store display dimensions
        self.tracked_objects = {}  # Dictionary of tracked objects
        self.next_obj_id = 0  # Counter for generating unique object IDs
        self.object_count = defaultdict(int)  # Count by class
        self.frame_count = 0  # Count processed frames
        self.red_zone_passed_objects = defaultdict(int)  # Objects that passed through red zone
        self.red_zone_warnings = []  # Store warning messages
        self.time_window = 10  # Configurable time window for similarity comparison
        self.similarity_threshold = 0.35  # Configurable similarity threshold (lowered for better matching)
        self.warning_frame_threshold = 3  # Configurable warning threshold (frames in red zone)
        # Enhanced red zone tracking
        self.red_zone_entered_objects = defaultdict(int)  # All objects that entered red zone
        self.red_zone_current_objects = defaultdict(list)  # Objects currently in red zone
        self.red_zone_exited_objects = defaultdict(int)  # Objects that exited red zone

    def reset_tracking(self):
        """Reset all tracking data"""
        self.tracked_objects = {}
        self.next_obj_id = 0
        self.object_count = defaultdict(int)
        self.frame_count = 0
        self.red_zone_passed_objects = defaultdict(int)
        self.red_zone_warnings = []
        # Reset enhanced red zone tracking
        self.red_zone_entered_objects = defaultdict(int)
        self.red_zone_current_objects = defaultdict(list)
        self.red_zone_exited_objects = defaultdict(int)

state = State()

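# Note: `state` is a single module-level instance, so every Gradio callback in
# this app shares one tracking context; two videos processed concurrently from
# separate browser sessions would therefore mix their tracking state.
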
def image_to_bytes(image):
    """Convert PIL Image to bytes for API request"""
    # Log original image size
    original_width, original_height = image.size
    print(f"Original image dimensions: {original_width}x{original_height}")

    # Convert image to bytes without resizing
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    print(f"Sending image with original dimensions: {original_width}x{original_height}")

    return img_byte_arr.getvalue()

def base64_to_image(base64_str):
    """Convert base64 string to OpenCV image"""
    img_data = base64.b64decode(base64_str)
    nparr = np.frombuffer(img_data, np.uint8)
    return cv2.imdecode(nparr, cv2.IMREAD_COLOR)

def opencv_to_pil(opencv_image):
    """Convert OpenCV image to PIL format"""
    # Convert from BGR to RGB for PIL
    rgb_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_image)

def scale_point_to_original(x, y):
    """Scale display coordinates back to original image coordinates"""
    if state.original_dims is None or state.display_dims is None:
        return x, y

    orig_w, orig_h = state.original_dims
    disp_w, disp_h = state.display_dims

    # Calculate scaling factors
    scale_x = orig_w / disp_w
    scale_y = orig_h / disp_h

    # Scale the coordinates
    orig_x = int(x * scale_x)
    orig_y = int(y * scale_y)

    return orig_x, orig_y

def scale_points_to_display(points):
    """Scale points from original image coordinates to display coordinates"""
    if state.original_dims is None or state.display_dims is None:
        return points

    orig_w, orig_h = state.original_dims
    disp_w, disp_h = state.display_dims

    # Calculate scaling factors
    scale_x = disp_w / orig_w
    scale_y = disp_h / orig_h

    # Scale all points
    display_points = []
    for point in points:
        x = int(point[0] * scale_x)
        y = int(point[1] * scale_y)
        display_points.append([x, y])

    return display_points

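# Worked example (illustrative dimensions): for a 1920x1080 original shown at
# 960x540, scale_x = scale_y = 0.5, so the original-space point [100, 200]
# maps to the display-space point [50, 100]; scale_point_to_original applies
# the inverse factors (2.0 here) in the other direction.
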
def draw_protection_area(image):
    """Draw protection area points and lines on the image"""
    img = image.copy()
    points = state.protection_points

    # Draw existing points and lines
    if len(points) > 0:
        # Convert points to numpy array
        points_array = np.array(points, dtype=np.int32)

        # Draw lines between points (closed polygon once 4 points exist)
        if len(points) > 1:
            cv2.polylines(img, [points_array],
                          len(points) == 4,
                          (0, 255, 0), 2)

        # Draw points with numbers
        for i, point in enumerate(points):
            cv2.circle(img, tuple(point), 5, (0, 0, 255), -1)
            cv2.putText(img, str(i + 1),
                        (point[0] + 10, point[1] + 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

        # Fill polygon with semi-transparent color if we have at least 3 points
        if len(points) >= 3:
            overlay = img.copy()
            cv2.fillPoly(overlay, [points_array], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)

    return img

def update_preview(video):
    if video is None:
        return None, gr.update(choices=[], value=[], visible=False)
    cap = cv2.VideoCapture(video)
    ret, frame = cap.read()
    cap.release()
    if ret:
        # Reset state
        state.protection_points = []
        state.detected_segments = []
        state.segment_image = None
        state.selected_segments = []
        state.previous_detections = None
        state.cached_protection_area = None

        # Store original frame and its dimensions
        state.current_image = frame.copy()  # Store the original frame
        state.original_dims = (frame.shape[1], frame.shape[0])  # (width, height)
        state.display_dims = state.original_dims  # Set display dims same as original

        # Convert to RGB without resizing
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return frame_rgb, gr.update(choices=[], value=[], visible=False)
    return None, gr.update(choices=[], value=[], visible=False)

def handle_image_click(evt: gr.SelectData, img):
    """Handle mouse clicks on the image"""
    if len(state.protection_points) >= 4:
        # Reset points if we already have 4
        state.protection_points = []

    if state.current_image is None:
        return img, "Error: No image loaded"

    # Get click coordinates from the event - these are in original scale
    click_x, click_y = evt.index[0], evt.index[1]

    # Add point directly (no scaling needed as we're working with original coordinates)
    state.protection_points.append([click_x, click_y])

    # Create a copy of the current image for display
    display_img = state.current_image.copy()

    # Draw points and lines
    for i, point in enumerate(state.protection_points):
        # Draw point
        cv2.circle(display_img, (point[0], point[1]), 5, (0, 0, 255), -1)
        cv2.putText(display_img, str(i + 1),
                    (point[0] + 10, point[1] + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

    # Draw lines between points
    if len(state.protection_points) > 1:
        points_array = np.array(state.protection_points, dtype=np.int32)

        # Draw lines (closed polygon once 4 points exist)
        cv2.polylines(display_img, [points_array],
                      len(state.protection_points) == 4,
                      (0, 255, 0), 2)

        # Fill polygon with semi-transparent color if we have at least 3 points
        if len(state.protection_points) >= 3:
            overlay = display_img.copy()
            cv2.fillPoly(overlay, [points_array], (0, 255, 0))
            cv2.addWeighted(overlay, 0.3, display_img, 0.7, 0, display_img)

    # Convert to RGB for display
    display_img_rgb = cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB)

    # Return the image and status
    return display_img_rgb, f"Selected {len(state.protection_points)} points\nCoordinates: {state.protection_points}"

def reset_points():
    """Reset protection points"""
    state.protection_points = []
    if state.current_image is not None:
        # Convert original image to RGB for display
        display_img_rgb = cv2.cvtColor(state.current_image.copy(), cv2.COLOR_BGR2RGB)
        return display_img_rgb, "Points reset"
    return None, "Points reset"

def detect_rail_segments(image):
    """Detect rail segments using the API"""
    try:
        # Log original image dimensions
        width, height = image.size
        print(f"Detecting rail segments on image with dimensions: {width}x{height}")

        files = {"file": image_to_bytes(image)}
        response = requests.post(
            f"{API_URL}/detect/rail-segment",
            files=files,
            timeout=60
        )

        if response.status_code == 200:
            result = response.json()
            if "segments" in result:
                return result["segments"], base64_to_image(result["image_base64"])
            else:
                return [], None
        else:
            print(f"API error: {response.status_code} - Image size was {width}x{height}")
            return [], None
    except Exception as e:
        print(f"Error in detect_rail_segments: {str(e)}")
        return [], None

def extract_protection_area(first_frame):
    """Extract and cache protection area points using rail segment detection"""
    try:
        # Log original frame dimensions
        height, width = first_frame.shape[:2]
        print(f"Extracting protection area from frame with dimensions: {width}x{height}")

        # Convert frame to PIL Image without resizing
        first_frame_pil = Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))

        # Verify PIL image dimensions
        pil_width, pil_height = first_frame_pil.size
        print(f"PIL Image dimensions before API call: {pil_width}x{pil_height}")

        # Detect rail segments
        segments, segment_img = detect_rail_segments(first_frame_pil)

        if segments and len(segments) > 0:
            # Verify segment image dimensions
            if segment_img is not None:
                seg_height, seg_width = segment_img.shape[:2]
                print(f"Received segment image dimensions: {seg_width}x{seg_height}")

                # Only resize if dimensions don't match
                if (seg_width, seg_height) != (width, height):
                    print(f"Resizing segment image from {seg_width}x{seg_height} to {width}x{height}")
                    segment_img = cv2.resize(segment_img, (width, height), interpolation=cv2.INTER_LANCZOS4)

            # Store segments and image
            state.detected_segments = segments
            state.segment_image = segment_img

            # Create segment choices with more detailed information
            segment_choices = []
            for i, segment in enumerate(segments):
                # Extract mask dimensions for verification
                mask_points = segment.get('mask', [])
                if mask_points:
                    mask_x = [p[0] for p in mask_points]
                    mask_y = [p[1] for p in mask_points]
                    mask_width = max(mask_x) - min(mask_x)
                    mask_height = max(mask_y) - min(mask_y)
                    print(f"Segment {i+1} mask dimensions: {mask_width}x{mask_height}")

                choice_text = f"Segment {i+1} (Confidence: {segment['confidence']:.2f})"
                segment_choices.append(choice_text)

            state.selected_segments = segment_choices  # Select all segments by default

            # Use the first segment's mask as protection area
            segment = segments[0]
            if 'mask' in segment and segment['mask']:
                mask_points = segment['mask']
                # Convert to list of [x,y] points and ensure integer values
                mask_points = [[int(float(x)), int(float(y))] for x, y in mask_points]
                if len(mask_points) >= 3:  # Need at least 3 points for a valid polygon
                    state.cached_protection_area = mask_points

                    # Convert segment image to RGB for display without resizing
                    if segment_img is not None:
                        display_img = cv2.cvtColor(segment_img, cv2.COLOR_BGR2RGB)
                        return True, "Protection area extracted successfully", display_img

            return False, "Invalid mask points in segment", None
        return False, "No valid rail segments detected", None

    except Exception as e:
        print(f"Error in extract_protection_area: {str(e)}")
        return False, f"Error extracting protection area: {str(e)}", None

def get_segment_index(choice_text):
    """Extract segment index from choice text"""
    try:
        # Extract index from "Segment X (Confidence: Y)" format
        return int(choice_text.split()[1]) - 1
    except (IndexError, ValueError):
        return -1

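# Parsing example: "Segment 3 (Confidence: 0.87)".split() yields
# ["Segment", "3", "(Confidence:", "0.87)"], so element [1] is "3" and the
# returned zero-based index is 2.
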
def update_object_tracking(objects_in_area, current_frame=None):
    """Update object tracking with new detections"""
    try:
        current_tracked = set()  # Keep track of objects seen in this frame
        current_warnings = []  # Collect warnings for this frame

        # Clear current objects list for this frame
        state.red_zone_current_objects = defaultdict(list)

        # Match new detections with existing tracked objects
        for obj in objects_in_area:
            try:
                if 'bbox' not in obj or 'class' not in obj:
                    continue

                bbox = obj['bbox']
                obj_class = obj['class']
                is_in_red_zone = obj.get('in_protection_area', False)
                matched = False
                best_match_id = None
                best_similarity = 0.0

                # Try to match with existing tracked objects using cross-class similarity
                for obj_id, tracked in state.tracked_objects.items():
                    # Check if object was seen recently (within time window)
                    if state.frame_count - tracked.last_seen <= state.time_window:
                        # Create temporary objects for similarity comparison
                        temp_obj1 = {'class': tracked.class_name, 'bbox': tracked.bboxes[-1] if tracked.bboxes else bbox}
                        temp_obj2 = {'class': obj_class, 'bbox': bbox}

                        if is_similar_object(temp_obj1, temp_obj2, state.similarity_threshold):
                            # Use enhanced similarity calculation with visual comparison
                            similarity = tracked.get_similarity_with(bbox, current_frame)

                            # Use the best match above threshold
                            if similarity > best_similarity:
                                best_similarity = similarity
                                best_match_id = obj_id

                # If good match found, update existing object
                if best_match_id is not None:
                    tracked = state.tracked_objects[best_match_id]
                    tracked.add_detection(bbox, current_frame)  # Pass current frame
                    tracked.update_class(obj_class)  # Update class information
                    tracked.last_seen = state.frame_count
                    current_tracked.add(best_match_id)
                    matched = True

                    # Check red zone status and handle state changes
                    zone_status = tracked.update_red_zone_status(is_in_red_zone, state.frame_count)
                    primary_class = tracked.get_primary_class()  # Use primary class for counting

                    if zone_status == "entered":
                        # Object just entered red zone - count it immediately
                        if not tracked.counted:
                            tracked.counted = True
                            state.red_zone_entered_objects[primary_class] += 1

                    elif zone_status == "warning":
                        warning_msg = f"⚠️ WARNING: {primary_class} (ID: {tracked.id}) has been in red zone for {tracked.frames_in_red_zone} frames!"
                        current_warnings.append(warning_msg)
                        state.red_zone_warnings.append({
                            'frame': state.frame_count,
                            'object_id': tracked.id,
                            'class': primary_class,
                            'frames_in_zone': tracked.frames_in_red_zone,
                            'message': warning_msg
                        })

                    elif zone_status == "exited":
                        # Object exited red zone
                        state.red_zone_exited_objects[primary_class] += 1

                    # Add to current objects in red zone if still in zone
                    if is_in_red_zone:
                        display_class = primary_class
                        if tracked.alternative_classes:
                            display_class += f" ({'+'.join(sorted(tracked.alternative_classes))})"

                        state.red_zone_current_objects[primary_class].append({
                            'id': tracked.id,
                            'frames_in_zone': tracked.frames_in_red_zone,
                            'entry_frame': tracked.red_zone_entry_frame,
                            'display_class': display_class
                        })

                # If no match found, create new tracked object
                if not matched:
                    new_obj = TrackedObject(state.next_obj_id, obj_class, bbox)
                    new_obj.add_detection(bbox, current_frame)  # Pass current frame
                    new_obj.last_seen = state.frame_count
                    new_obj.first_seen = state.frame_count
                    state.tracked_objects[state.next_obj_id] = new_obj
                    current_tracked.add(state.next_obj_id)

                    # Check red zone status for new object
                    zone_status = new_obj.update_red_zone_status(is_in_red_zone, state.frame_count)
                    primary_class = new_obj.get_primary_class()

                    if zone_status == "entered":
                        # New object entered red zone immediately
                        new_obj.counted = True
                        state.red_zone_entered_objects[primary_class] += 1

                        # Add to current objects in red zone
                        state.red_zone_current_objects[primary_class].append({
                            'id': new_obj.id,
                            'frames_in_zone': new_obj.frames_in_red_zone,
                            'entry_frame': new_obj.red_zone_entry_frame,
                            'display_class': primary_class
                        })

                    state.next_obj_id += 1

            except Exception:
                continue

        # Update objects not seen in current frame
        for obj_id, tracked in state.tracked_objects.items():
            if obj_id not in current_tracked:
                # Object not seen in current frame, update red zone status
                zone_status = tracked.update_red_zone_status(False, state.frame_count)
                if zone_status == "exited":
                    # Object exited red zone
                    primary_class = tracked.get_primary_class()
                    state.red_zone_exited_objects[primary_class] += 1

        # Remove objects that haven't been seen for a while
        if state.frame_count > state.time_window:
            to_remove = []
            for obj_id, tracked in state.tracked_objects.items():
                if state.frame_count - tracked.last_seen > state.time_window * 2:  # Remove after 2x time window
                    # If object was in red zone when lost, count as exited
                    if tracked.red_zone_entry_frame is not None:
                        primary_class = tracked.get_primary_class()
                        state.red_zone_exited_objects[primary_class] += 1
                    to_remove.append(obj_id)

            for obj_id in to_remove:
                del state.tracked_objects[obj_id]

        # Store current warnings
        if current_warnings:
            print(f"Frame {state.frame_count} Warnings: {current_warnings}")

    except Exception as e:
        print(f"Error in update_object_tracking: {str(e)}")

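# Matching flow in brief: each detection is gated against every recently seen
# track via is_similar_object, scored with the enhanced visual similarity,
# greedily assigned to the single best-scoring track, and otherwise spawns a
# new TrackedObject; tracks unseen for 2x the time window are retired.
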
def get_red_zone_summary():
    """Generate comprehensive summary of objects in red zone with proper grouping"""
    summary = []

    # Header
    summary.append("🔴 RED ZONE MONITORING REPORT")
    summary.append("=" * 40)

    # Objects that entered red zone (all time)
    if state.red_zone_entered_objects:
        summary.append("\n📊 OBJECTS ENTERED RED ZONE:")
        total_entered = sum(state.red_zone_entered_objects.values())
        summary.append(f"Total objects entered: {total_entered}")

        for obj_class, count in sorted(state.red_zone_entered_objects.items()):
            summary.append(f"  • {obj_class}: {count}")
    else:
        summary.append("\n📊 OBJECTS ENTERED RED ZONE:")
        summary.append("No objects have entered the red zone yet")

    # Objects currently in red zone
    current_total = sum(len(objects) for objects in state.red_zone_current_objects.values())
    if current_total > 0:
        summary.append(f"\n🚨 CURRENTLY IN RED ZONE ({current_total} objects):")

        for obj_class, objects in sorted(state.red_zone_current_objects.items()):
            if objects:
                summary.append(f"  {obj_class} ({len(objects)} objects):")
                for obj_info in objects:
                    display_class = obj_info.get('display_class', obj_class)
                    summary.append(f"    - ID {obj_info['id']}: {obj_info['frames_in_zone']} frames (entered: frame {obj_info['entry_frame']}) [{display_class}]")
    else:
        summary.append("\n🚨 CURRENTLY IN RED ZONE:")
        summary.append("No objects currently in red zone")

    # Objects that exited red zone
    if state.red_zone_exited_objects:
        summary.append("\n✅ OBJECTS EXITED RED ZONE:")
        total_exited = sum(state.red_zone_exited_objects.values())
        summary.append(f"Total objects exited: {total_exited}")

        for obj_class, count in sorted(state.red_zone_exited_objects.items()):
            summary.append(f"  • {obj_class}: {count}")

    # Recent warnings
    recent_warnings = [w for w in state.red_zone_warnings if state.frame_count - w['frame'] <= 10]
    if recent_warnings:
        summary.append("\n⚠️ RECENT WARNINGS:")
        for warning in recent_warnings[-5:]:  # Show last 5 warnings
            summary.append(f"  • Frame {warning['frame']}: {warning['class']} (ID: {warning['object_id']}) - {warning['frames_in_zone']} frames in zone")

    # Statistics summary
    summary.append("\n📈 STATISTICS:")
    summary.append(f"  • Total unique objects tracked: {len(state.tracked_objects)}")
    summary.append(f"  • Active warnings: {len([w for w in state.red_zone_warnings if state.frame_count - w['frame'] <= 5])}")
    summary.append(f"  • Frame: {state.frame_count}")
    summary.append(f"  • Warning threshold: {state.warning_frame_threshold} frames")

    # Show object combination info
    combined_objects = sum(1 for tracked in state.tracked_objects.values() if tracked.alternative_classes)

    if combined_objects > 0:
        summary.append(f"  • Objects with combined detections: {combined_objects}")

    return "\n".join(summary)

def process_frame(frame, confidence):
    """Process a video frame using cached protection area"""
    try:
        protection_area = []
        if state.selected_segments and state.detected_segments:
            for choice in state.selected_segments:
                idx = get_segment_index(choice)
                if 0 <= idx < len(state.detected_segments):
                    segment = state.detected_segments[idx]
                    if 'mask' in segment and segment['mask']:
                        protection_area = segment['mask']
                        break
        elif len(state.protection_points) >= 3:
            protection_area = state.protection_points

        if not protection_area:
            return None, "Protection area not set. Please extract protection area first."

        # Ensure frame is valid
        if frame is None or frame.size == 0:
            return None, "Invalid frame"

        success, buffer = cv2.imencode('.png', frame)
        if not success:
            return None, "Failed to encode frame"

        files = {
            "file": ("frame.png", buffer.tobytes(), "image/png")
        }

        protection_area_json = json.dumps(protection_area)

        data = {
            "protection_area": protection_area_json,
            "confidence_threshold": str(confidence)
        }

        if state.previous_detections:
            data["previous_detections"] = json.dumps(state.previous_detections)

        try:
            response = requests.post(
                f"{API_URL}/detect/objects-and-redlight",
                files=files,
                data=data,
                timeout=60
            )

            if response.status_code == 200:
                result = response.json()
                if not result.get("success"):
                    return None, f"API Error: {result.get('detail', 'Unknown error')}"

                result_data = result.get("result", {})
                if not result_data:
                    return None, "No result data received"

                red_light_info = result_data.get("red_light", {})
                red_light_detected = red_light_info.get("detected", False)
                red_light_prob = red_light_info.get("probability", 0)

                img_base64 = result_data.get("image_base64")
                if not img_base64:
                    return None, "No image data received from API"

                try:
                    if ',' in img_base64:
                        img_base64 = img_base64.split(',')[1]

                    img_data = base64.b64decode(img_base64)
                    nparr = np.frombuffer(img_data, np.uint8)
                    processed_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                    if processed_img is None or processed_img.size == 0:
                        return None, "Failed to decode image from API response"

                    objects_in_area = [obj for obj in result_data.get("objects", [])
                                       if obj.get("in_protection_area", False) and
                                       'bbox' in obj and 'class' in obj]

                    # Update object tracking
                    state.frame_count += 1
                    update_object_tracking(objects_in_area, processed_img)

                    # Cache detections for next frame
                    state.previous_detections = objects_in_area

                    processed_img_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)

                    status = []
                    status.append(f"Red Light: {'YES' if red_light_detected else 'NO'} ({red_light_prob:.2f})")

                    # Add enhanced red zone summary
                    red_zone_summary = get_red_zone_summary()
                    status.append(f"\n{red_zone_summary}")

                    if objects_in_area:
                        status.append("\n📊 CURRENT FRAME DETECTIONS:")
                        for obj in objects_in_area:
                            status.append(f"  • {obj['class']} (confidence: {obj['confidence']:.2f})")

                    # Add tracking statistics
                    active_objects = len([obj for obj in state.tracked_objects.values()
                                          if state.frame_count - obj.last_seen <= 3])
                    status.append("\n📈 TRACKING STATS:")
                    status.append(f"  • Active tracked objects: {active_objects}")
                    status.append(f"  • Frame: {state.frame_count}")
                    status.append(f"  • Time window: {state.time_window} frames")
                    status.append(f"  • Similarity threshold: {state.similarity_threshold:.2f}")

                    return processed_img_rgb, "\n".join(status)

                except Exception as e:
                    return None, f"Error processing detection results: {str(e)}"
            else:
                error_detail = f"API Error: {response.status_code}"
                try:
                    error_json = response.json()
                    if 'detail' in error_json:
                        error_detail += f" - {error_json['detail']}"
                except Exception:
                    error_detail += f" - {response.text}"
                return None, error_detail

        except requests.exceptions.Timeout:
            return None, "API request timed out"
        except requests.exceptions.ConnectionError:
            return None, "Could not connect to API server"
        except Exception as e:
            return None, f"API request failed: {str(e)}"

    except Exception as e:
        return None, f"Error processing frame: {str(e)}"

def process_video(video, confidence=DEFAULT_CONFIDENCE, target_fps=1):
    """Stream processed frames in real-time using cached protection area"""
    cap = cv2.VideoCapture(video)

    if not cap.isOpened():
        yield None, "Error: Could not open video file"
        return

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_interval = max(1, int(fps / target_fps))
    frame_number = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_number += 1
            if frame_number % frame_interval != 0:
                continue

            # Process frame and get results
            processed_frame, result = process_frame(frame, confidence)

            if processed_frame is not None:
                # Frame is already in RGB format from process_frame
                current_status = f"Processing frame {frame_number}/{total_frames}\n{result}"
                yield processed_frame, current_status
            else:
                current_status = f"Frame {frame_number}: {result}"
                yield None, current_status

        # Generate final summary
        final_summary = generate_final_summary()
        yield None, final_summary

    except Exception as e:
        yield None, f"Error processing video: {str(e)}"
    finally:
        cap.release()

def generate_final_summary():
    """Generate comprehensive final summary of video processing"""
    summary_lines = []

    summary_lines.append("🎬 VIDEO PROCESSING COMPLETE")
    summary_lines.append("=" * 50)

    # Processing statistics
    summary_lines.append("📊 PROCESSING STATISTICS:")
    summary_lines.append(f"  • Total frames processed: {state.frame_count}")
    summary_lines.append(f"  • Time window used: {state.time_window} frames")
    summary_lines.append(f"  • Similarity threshold: {state.similarity_threshold:.2f}")
    summary_lines.append(f"  • Warning threshold: {state.warning_frame_threshold} frames")

    # Enhanced red zone summary
    if state.red_zone_entered_objects:
        summary_lines.append("\n🔴 RED ZONE ANALYSIS:")
        total_entered = sum(state.red_zone_entered_objects.values())
        total_exited = sum(state.red_zone_exited_objects.values())

        summary_lines.append(f"  • Total objects entered red zone: {total_entered}")
        summary_lines.append(f"  • Total objects exited red zone: {total_exited}")
        summary_lines.append(f"  • Objects still in red zone: {total_entered - total_exited}")

        summary_lines.append("\n  📋 BREAKDOWN BY OBJECT CLASS:")

        # Combine all object classes that appeared in red zone
        all_classes = set(state.red_zone_entered_objects.keys()) | set(state.red_zone_exited_objects.keys())

        for obj_class in sorted(all_classes):
            entered = state.red_zone_entered_objects.get(obj_class, 0)
            exited = state.red_zone_exited_objects.get(obj_class, 0)
            still_in = entered - exited

            summary_lines.append(f"    {obj_class}:")
            summary_lines.append(f"      - Entered: {entered}")
            summary_lines.append(f"      - Exited: {exited}")
            summary_lines.append(f"      - Still in zone: {still_in}")
    else:
        summary_lines.append("\n🔴 RED ZONE ANALYSIS:")
        summary_lines.append("  • No objects detected in red zone during processing")

    # Object combination analysis
    combined_objects = []
    motorcycle_person_combinations = 0

    for obj_id, tracked in state.tracked_objects.items():
        if tracked.alternative_classes:
            combo_info = f"ID {obj_id}: {tracked.class_name} + {', '.join(sorted(tracked.alternative_classes))}"
            combined_objects.append(combo_info)

            # Count motorcycle+person combinations specifically
            if (tracked.class_name == 'motorcycle' and 'person' in tracked.alternative_classes) or \
               (tracked.class_name == 'person' and 'motorcycle' in tracked.alternative_classes):
                motorcycle_person_combinations += 1

    if combined_objects:
        summary_lines.append("\n🔗 OBJECT COMBINATIONS DETECTED:")
        summary_lines.append(f"  • Total combined detections: {len(combined_objects)}")
        summary_lines.append(f"  • Motorcycle+Person combinations: {motorcycle_person_combinations}")
        summary_lines.append("  • Details:")
        for combo in combined_objects:
            summary_lines.append(f"    - {combo}")

    # Warning summary
    if state.red_zone_warnings:
        summary_lines.append("\n⚠️ WARNING SUMMARY:")
        summary_lines.append(f"  • Total warnings generated: {len(state.red_zone_warnings)}")

        # Group warnings by object class
        warning_by_class = defaultdict(int)
        for warning in state.red_zone_warnings:
            warning_by_class[warning['class']] += 1

        for obj_class, count in sorted(warning_by_class.items()):
            summary_lines.append(f"    - {obj_class}: {count} warnings")

        # Show detailed warning log
        summary_lines.append("\n  📋 Warning Log (last 10):")
        for warning in state.red_zone_warnings[-10:]:  # Last 10 warnings
            summary_lines.append(f"    - Frame {warning['frame']}: {warning['class']} (ID: {warning['object_id']}) - {warning['frames_in_zone']} frames in zone")
    else:
        summary_lines.append("\n⚠️ WARNING SUMMARY:")
        summary_lines.append(f"  • No warnings generated (no objects stayed in red zone > {state.warning_frame_threshold} frames)")

    # Active tracking summary
    total_tracked = len(state.tracked_objects)
    if total_tracked > 0:
        summary_lines.append("\n📈 OBJECT TRACKING SUMMARY:")
        summary_lines.append(f"  • Total unique objects tracked: {total_tracked}")

        # Group by primary class
        objects_by_class = defaultdict(int)
        for obj in state.tracked_objects.values():
            primary_class = obj.get_primary_class()
            objects_by_class[primary_class] += 1

        for obj_class, count in sorted(objects_by_class.items()):
            summary_lines.append(f"    - {obj_class}: {count}")

    summary_lines.append("\n✅ Processing completed successfully!")
    summary_lines.append("\nNote: Objects detected as both motorcycle and person are counted as motorcycle (person riding motorcycle)")

    return "\n".join(summary_lines)

def extract_area_from_video(video):
    if video is None:
        return None, "Please upload a video", gr.update(choices=[], value=[], visible=False)

    cap = cv2.VideoCapture(video)
    ret, frame = cap.read()
    cap.release()

    if not ret:
        return None, "Could not read video frame", gr.update(choices=[], value=[], visible=False)

    success, message, segment_img = extract_protection_area(frame)
    if success and segment_img is not None:
        # extract_protection_area already returns an RGB image, so use it directly
        # (converting again would swap the channels back to BGR)
        segment_img_rgb = segment_img

        # Create segment choices
        segment_choices = [f"Segment {i+1} (Confidence: {segment['confidence']:.2f})"
                           for i, segment in enumerate(state.detected_segments)]

        return segment_img_rgb, message, gr.update(choices=segment_choices, value=segment_choices, visible=True)
    return None, message, gr.update(choices=[], value=[], visible=False)

def update_selected_segments(selected):
    if selected is None:
        selected = []
    state.selected_segments = selected
    return gr.update()

def process_video_wrapper(video, confidence=DEFAULT_CONFIDENCE, target_fps=1, time_window=10, similarity_threshold=0.35, warning_frame_threshold=3):
    """Wrapper around process_video to handle full-size video processing"""
    if video is None:
        yield None, "Please upload a video"
        return

    # Reset tracking state and update parameters
    state.reset_tracking()
    state.time_window = time_window
    state.similarity_threshold = similarity_threshold
    state.warning_frame_threshold = warning_frame_threshold

    protection_area = []
    if state.selected_segments and state.detected_segments:
        for choice in state.selected_segments:
            idx = get_segment_index(choice)
            if 0 <= idx < len(state.detected_segments):
                segment = state.detected_segments[idx]
                if 'mask' in segment and segment['mask']:
                    protection_area = segment['mask']
                    break
    elif len(state.protection_points) >= 3:
        protection_area = state.protection_points

    if not protection_area:
        yield None, "Please extract protection area first"
        return

    try:
        yield None, f"🚀 Starting video processing...\n⚙️ Time window: {time_window} frames\n⚙️ Similarity threshold: {similarity_threshold:.2f}\n⚙️ Warning threshold: {warning_frame_threshold} frames"

        for frame, status in process_video(video, confidence, target_fps):
            yield frame, status

    except Exception as e:
        yield None, f"Error processing video: {str(e)}"

# Enhanced Gradio interface
with gr.Blocks(title="Enhanced Rail Traffic Monitor") as demo:
    gr.Markdown("""
    # Enhanced Rail Traffic Monitoring System

    ## Features:
    - **Smart Object Tracking**: Uses an enhanced similarity method combining geometric and visual features
    - **Visual Similarity Comparison**: Compares actual images within bounding boxes using multiple methods
    - **Comprehensive Red Zone Monitoring**: Reports ALL objects entering the red zone
    - **Enhanced Grouping**: Groups objects by class with detailed statistics
    - **Real-time Status**: Shows objects currently in zone, entered, and exited
    - **Configurable Warning System**: Alerts when objects stay in the red zone for too long
    - **Configurable Parameters**: Adjust time window, similarity threshold, and warning criteria

    ## Enhanced Similarity Methods:
    - **Geometric Similarity** (40%): IoU + center distance
    - **Color Histogram** (25%): HSV color distribution comparison
    - **Structural Similarity** (25%): SSIM for shape and texture
    - **Feature Matching** (10%): ORB keypoint matching
    - **Default Threshold**: 0.35 (more lenient for better object matching)

    ## Red Zone Reporting:
    - **Objects Entered**: Total count of all objects that entered the red zone
    - **Currently in Zone**: Real-time list of objects currently in the red zone
    - **Objects Exited**: Count of objects that have left the red zone
    - **Detailed Grouping**: All statistics grouped by object class (train, car, person, etc.)

    ## Setup Instructions:

    **Method 1 (Manual Protection Area):**
    1. Click 4 points on the image to define the protection area
    2. Click "Reset Points" to start over

    **Method 2 (Automatic Detection):**
    1. Click "Extract Protection Area" to automatically detect rail segments

    **Processing:**
    3. Adjust detection confidence, processing frame rate, time window, similarity threshold, and warning threshold
    4. Click "Process Video" to analyze

    The system will show comprehensive real-time results including:
    - All objects that entered the red zone (grouped by class)
    - Objects currently in the red zone with detailed info
    - Objects that exited the red zone
    - Enhanced tracking with visual similarity comparison
    - Configurable warnings for objects staying too long in the red zone
    - Complete tracking statistics
    """)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Input Video"
            )
            with gr.Row():
                confidence = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_CONFIDENCE,
                    label="Detection Confidence Threshold",
                    info="Minimum confidence for object detection"
                )
                fps_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=1,
                    step=1,
                    label="Processing Frame Rate (FPS)",
                    info="Frames per second to process"
                )

            with gr.Row():
                time_window_slider = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Time Window (frames)",
                    info="Number of frames to consider for object similarity"
                )
                similarity_threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.35,
                    step=0.05,
                    label="Similarity Threshold",
                    info="Threshold for considering objects as the same (higher = stricter)"
                )

            with gr.Row():
                warning_threshold_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=3,
                    step=1,
                    label="Warning Frame Threshold",
                    info="Number of frames in red zone before triggering warning"
                )
        with gr.Column():
            preview_image = gr.Image(
                label="Click to Select Protection Area (Original Size)",
                interactive=True,
                show_label=True
            )

            # Add segment selection dropdown
            segment_dropdown = gr.Dropdown(
                label="Selected Segments",
                choices=[],
                multiselect=True,
                interactive=True,
                visible=False,
                value=[]
            )

    with gr.Row():
        reset_btn = gr.Button("Reset Points")
        extract_btn = gr.Button("Extract Protection Area")
        process_btn = gr.Button("🚀 Process Video")

    with gr.Row():
        video_output = gr.Image(
            label="Live Processing Output",
            streaming=True,
            interactive=False,
            show_label=True,
            container=True,
            show_download_button=True
        )
        text_output = gr.Textbox(
            label="Detection Results & Red Zone Summary",
            lines=15,
            max_lines=20,
            show_copy_button=True
        )

    # Handle video upload to populate preview
    video_input.change(
        fn=update_preview,
        inputs=[video_input],
        outputs=[preview_image, segment_dropdown]
    )

    extract_btn.click(
        fn=extract_area_from_video,
        inputs=[video_input],
        outputs=[preview_image, text_output, segment_dropdown]
    )

    segment_dropdown.change(
        fn=update_selected_segments,
        inputs=[segment_dropdown],
        outputs=[segment_dropdown]
    )

    process_btn.click(
        fn=process_video_wrapper,
        inputs=[video_input, confidence, fps_slider, time_window_slider, similarity_threshold_slider, warning_threshold_slider],
        outputs=[video_output, text_output]
    )

    # Add click event handler
    preview_image.select(
        fn=handle_image_click,
        inputs=[preview_image],
        outputs=[preview_image, text_output]
    )

    # Add reset button handler
    reset_btn.click(
        fn=reset_points,
        inputs=[],
        outputs=[preview_image, text_output]
    )

if __name__ == "__main__":
    demo.queue().launch()
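# To run the app locally (an assumed setup, not stated in the original commit):
#   pip install -r requirements.txt
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default; set API_URL in a
# .env file if the detection backend lives elsewhere.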
requirements.txt ADDED
@@ -0,0 +1,28 @@
# Core API dependencies
fastapi==0.110.0
uvicorn==0.27.1
python-multipart==0.0.9
pydantic==2.6.3

# Machine Learning and Computer Vision
torch==2.2.1
torchvision==0.17.1
ultralytics==8.1.28
opencv-python==4.9.0.80
numpy==1.26.3
Pillow==10.2.0
scikit-image==0.22.0

# Data handling and utilities
PyYAML==6.0.1
requests==2.31.0
python-dotenv==1.0.1

# Frontend (Gradio app)
gradio==4.19.2

# Development and testing
tqdm>=4.66.0

# Additional utilities (built-in modules - no need to install)
# os, io, sys, json, base64, logging, traceback, pathlib, threading, time, pickle, concurrent.futures, collections