hb-setosys commited on
Commit
0225f43
·
verified ·
1 Parent(s): 9b6b6b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -41
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import torch
5
  from ultralytics import YOLO
6
  from sort import Sort
7
- import gradio as gr
8
 
9
  # Load YOLOv12x model
10
  MODEL_PATH = "yolov12x.pt"
@@ -14,32 +13,44 @@ model = YOLO(MODEL_PATH)
14
  TRUCK_CLASS_ID = 7 # "truck"
15
 
16
  # Initialize SORT tracker
17
- tracker = Sort(max_age=20, min_hits=3, iou_threshold=0.3) # Improved tracking stability
18
 
19
  # Minimum confidence threshold for detection
20
- CONFIDENCE_THRESHOLD = 0.4 # Adjusted to capture more trucks
21
 
22
  # Distance threshold to avoid duplicate counts
23
  DISTANCE_THRESHOLD = 50
24
 
25
  # Dictionary to define keyword-based time intervals
26
  TIME_INTERVALS = {
27
- "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
28
- "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11
 
 
 
 
 
 
 
 
 
29
  }
30
 
 
31
  def determine_time_interval(video_filename):
32
- """ Determines frame skip interval based on keywords in the filename. """
33
  for keyword, interval in TIME_INTERVALS.items():
34
  if keyword in video_filename:
 
35
  return interval
36
- return 5 # Default interval
 
 
37
 
38
  def count_unique_trucks(video_path):
39
- """ Counts unique trucks in a video using YOLOv12x and SORT tracking. """
40
  cap = cv2.VideoCapture(video_path)
41
  if not cap.isOpened():
42
- return {"Error": "Unable to open video file."}
43
 
44
  unique_truck_ids = set()
45
  truck_history = {}
@@ -51,13 +62,14 @@ def count_unique_trucks(video_path):
51
  video_filename = os.path.basename(video_path).lower()
52
 
53
  # Determine the dynamic time interval based on filename keywords
54
- #time_interval = determine_time_interval(video_filename)
55
- time_interval = 7
56
  # Get total frames in the video
57
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
58
 
59
- # Dynamically adjust frame skipping based on FPS and movement density
60
- frame_skip = max(1, min(fps * time_interval // 2, total_frames // 10))
 
 
61
 
62
  frame_count = 0
63
 
@@ -68,7 +80,7 @@ def count_unique_trucks(video_path):
68
 
69
  frame_count += 1
70
  if frame_count % frame_skip != 0:
71
- continue # Skip frames based on interval
72
 
73
  # Run YOLOv12x inference
74
  results = model(frame, verbose=False)
@@ -84,39 +96,35 @@ def count_unique_trucks(video_path):
84
  x1, y1, x2, y2 = map(int, box.xyxy[0]) # Get bounding box
85
  detections.append([x1, y1, x2, y2, confidence])
86
 
87
- # Convert detections to numpy array for SORT
88
- detections = np.array(detections) if len(detections) > 0 else np.empty((0, 5))
 
89
 
90
- # Update SORT tracker
91
- tracked_objects = tracker.update(detections)
 
92
 
93
- # Track movement history to avoid duplicate counts
94
- for obj in tracked_objects:
95
- truck_id = int(obj[4]) # Unique ID assigned by SORT
96
- x1, y1, x2, y2 = obj[:4] # Get bounding box coordinates
97
 
98
- truck_center = (x1 + x2) / 2, (y1 + y2) / 2 # Calculate truck center
 
 
 
99
 
100
- # Entry-exit zone logic (e.g., bottom 20% of the frame)
101
- frame_height, frame_width = frame.shape[:2]
102
- entry_line = frame_height * 0.8 # Bottom 20% of the frame
103
- exit_line = frame_height * 0.2 # Top 20% of the frame
104
 
105
- if truck_id not in truck_history:
106
- # New truck detected
107
- truck_history[truck_id] = {
108
- "position": truck_center,
109
- "crossed_entry": truck_center[1] > entry_line,
110
- "crossed_exit": False
111
- }
112
- continue
113
-
114
- # If the truck crosses from entry to exit, count it
115
- if truck_history[truck_id]["crossed_entry"] and truck_center[1] < exit_line:
116
- truck_history[truck_id]["crossed_exit"] = True
117
- unique_truck_ids.add(truck_id)
118
 
119
  cap.release()
 
120
  return {"Total Unique Trucks": len(unique_truck_ids)}
121
 
122
  # Gradio UI function
@@ -125,6 +133,7 @@ def analyze_video(video_file):
125
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
126
 
127
  # Define Gradio interface
 
128
  iface = gr.Interface(
129
  fn=analyze_video,
130
  inputs=gr.Video(label="Upload Video"),
@@ -135,4 +144,4 @@ iface = gr.Interface(
135
 
136
  # Launch the Gradio app
137
  if __name__ == "__main__":
138
- iface.launch()
 
4
  import torch
5
  from ultralytics import YOLO
6
  from sort import Sort
 
7
 
8
  # Load YOLOv12x model
9
  MODEL_PATH = "yolov12x.pt"
 
13
  TRUCK_CLASS_ID = 7 # "truck"
14
 
15
  # Initialize SORT tracker
16
+ tracker = Sort()
17
 
18
  # Minimum confidence threshold for detection
19
+ CONFIDENCE_THRESHOLD = 0.5
20
 
21
  # Distance threshold to avoid duplicate counts
22
  DISTANCE_THRESHOLD = 50
23
 
24
  # Dictionary to define keyword-based time intervals
25
  TIME_INTERVALS = {
26
+ "one": 1,
27
+ "two": 2,
28
+ "three": 3,
29
+ "four": 4,
30
+ "five": 5,
31
+ "six": 6,
32
+ "seven": 7,
33
+ "eight": 8,
34
+ "nine": 9,
35
+ "ten": 10,
36
+ "eleven": 11
37
  }
38
 
39
+
40
  def determine_time_interval(video_filename):
41
+ print(f"Checking filename: {video_filename}") # Debugging
42
  for keyword, interval in TIME_INTERVALS.items():
43
  if keyword in video_filename:
44
+ print(f"Matched keyword: {keyword} -> Interval: {interval}") # Debugging
45
  return interval
46
+ print("No keyword match, using default interval: 5") # Debugging
47
+ return 5 # Default interval if no keyword matches
48
+
49
 
50
  def count_unique_trucks(video_path):
 
51
  cap = cv2.VideoCapture(video_path)
52
  if not cap.isOpened():
53
+ return "Error: Unable to open video file."
54
 
55
  unique_truck_ids = set()
56
  truck_history = {}
 
62
  video_filename = os.path.basename(video_path).lower()
63
 
64
  # Determine the dynamic time interval based on filename keywords
65
+ time_interval = determine_time_interval(video_filename)
 
66
  # Get total frames in the video
67
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
68
 
69
+ # Ensure frame_skip does not exceed total frames
70
+ frame_skip = min(fps * time_interval, total_frames)
71
+
72
+ #frame_skip = fps * time_interval # Convert time interval to frame count
73
 
74
  frame_count = 0
75
 
 
80
 
81
  frame_count += 1
82
  if frame_count % frame_skip != 0:
83
+ continue # Skip frames to process only every 5 seconds
84
 
85
  # Run YOLOv12x inference
86
  results = model(frame, verbose=False)
 
96
  x1, y1, x2, y2 = map(int, box.xyxy[0]) # Get bounding box
97
  detections.append([x1, y1, x2, y2, confidence])
98
 
99
+ if len(detections) > 0:
100
+ detections = np.array(detections)
101
+ tracked_objects = tracker.update(detections)
102
 
103
+ for obj in tracked_objects:
104
+ truck_id = int(obj[4]) # Unique ID assigned by SORT
105
+ x1, y1, x2, y2 = obj[:4] # Get the bounding box coordinates
106
 
107
+ truck_center = (x1 + x2) / 2, (y1 + y2) / 2 # Calculate the center of the truck
 
 
 
108
 
109
+ # If truck is already in history, check the movement distance
110
+ if truck_id in truck_history:
111
+ last_position = truck_history[truck_id]["position"]
112
+ distance = np.linalg.norm(np.array(truck_center) - np.array(last_position))
113
 
114
+ if distance > DISTANCE_THRESHOLD:
115
+ # If the truck moved significantly, count as new
116
+ unique_truck_ids.add(truck_id)
 
117
 
118
+ else:
119
+ # If truck is not in history, add it
120
+ truck_history[truck_id] = {
121
+ "frame_count": frame_count,
122
+ "position": truck_center
123
+ }
124
+ unique_truck_ids.add(truck_id)
 
 
 
 
 
 
125
 
126
  cap.release()
127
+
128
  return {"Total Unique Trucks": len(unique_truck_ids)}
129
 
130
  # Gradio UI function
 
133
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
134
 
135
  # Define Gradio interface
136
+ import gradio as gr
137
  iface = gr.Interface(
138
  fn=analyze_video,
139
  inputs=gr.Video(label="Upload Video"),
 
144
 
145
  # Launch the Gradio app
146
  if __name__ == "__main__":
147
+ iface.launch()