hb-setosys commited on
Commit
7f2b13a
·
verified ·
1 Parent(s): 121700f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -51
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import torch
5
  from ultralytics import YOLO
6
  from sort import Sort
 
7
 
8
  # Load YOLOv12x model
9
  MODEL_PATH = "yolov12x.pt"
@@ -13,44 +14,32 @@ model = YOLO(MODEL_PATH)
13
  TRUCK_CLASS_ID = 7 # "truck"
14
 
15
  # Initialize SORT tracker
16
- tracker = Sort()
17
 
18
  # Minimum confidence threshold for detection
19
- CONFIDENCE_THRESHOLD = 0.5
20
 
21
  # Distance threshold to avoid duplicate counts
22
  DISTANCE_THRESHOLD = 50
23
 
24
  # Dictionary to define keyword-based time intervals
25
  TIME_INTERVALS = {
26
- "one": 1,
27
- "two": 2,
28
- "three": 3,
29
- "four": 4,
30
- "five": 5,
31
- "six": 6,
32
- "seven": 7,
33
- "eight": 8,
34
- "nine": 9,
35
- "ten": 10,
36
- "eleven": 11
37
  }
38
 
39
-
40
  def determine_time_interval(video_filename):
41
- print(f"Checking filename: {video_filename}") # Debugging
42
  for keyword, interval in TIME_INTERVALS.items():
43
  if keyword in video_filename:
44
- print(f"Matched keyword: {keyword} -> Interval: {interval}") # Debugging
45
  return interval
46
- print("No keyword match, using default interval: 5") # Debugging
47
- return 5 # Default interval if no keyword matches
48
-
49
 
50
  def count_unique_trucks(video_path):
 
51
  cap = cv2.VideoCapture(video_path)
52
  if not cap.isOpened():
53
- return "Error: Unable to open video file."
54
 
55
  unique_truck_ids = set()
56
  truck_history = {}
@@ -62,15 +51,13 @@ def count_unique_trucks(video_path):
62
  video_filename = os.path.basename(video_path).lower()
63
 
64
  # Determine the dynamic time interval based on filename keywords
65
- #time_interval = determine_time_interval(video_filename)
66
- time_interval = 7
67
  # Get total frames in the video
68
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
69
 
70
- # Ensure frame_skip does not exceed total frames
71
- frame_skip = min(fps * time_interval, total_frames)
72
-
73
- #frame_skip = fps * time_interval # Convert time interval to frame count
74
 
75
  frame_count = 0
76
 
@@ -81,7 +68,7 @@ def count_unique_trucks(video_path):
81
 
82
  frame_count += 1
83
  if frame_count % frame_skip != 0:
84
- continue # Skip frames to process only every 5 seconds
85
 
86
  # Run YOLOv12x inference
87
  results = model(frame, verbose=False)
@@ -97,35 +84,39 @@ def count_unique_trucks(video_path):
97
  x1, y1, x2, y2 = map(int, box.xyxy[0]) # Get bounding box
98
  detections.append([x1, y1, x2, y2, confidence])
99
 
100
- if len(detections) > 0:
101
- detections = np.array(detections)
102
- tracked_objects = tracker.update(detections)
103
 
104
- for obj in tracked_objects:
105
- truck_id = int(obj[4]) # Unique ID assigned by SORT
106
- x1, y1, x2, y2 = obj[:4] # Get the bounding box coordinates
107
 
108
- truck_center = (x1 + x2) / 2, (y1 + y2) / 2 # Calculate the center of the truck
 
 
 
109
 
110
- # If truck is already in history, check the movement distance
111
- if truck_id in truck_history:
112
- last_position = truck_history[truck_id]["position"]
113
- distance = np.linalg.norm(np.array(truck_center) - np.array(last_position))
114
 
115
- if distance > DISTANCE_THRESHOLD:
116
- # If the truck moved significantly, count as new
117
- unique_truck_ids.add(truck_id)
 
118
 
119
- else:
120
- # If truck is not in history, add it
121
- truck_history[truck_id] = {
122
- "frame_count": frame_count,
123
- "position": truck_center
124
- }
125
- unique_truck_ids.add(truck_id)
 
126
 
127
- cap.release()
 
 
 
128
 
 
129
  return {"Total Unique Trucks": len(unique_truck_ids)}
130
 
131
  # Gradio UI function
@@ -134,7 +125,6 @@ def analyze_video(video_file):
134
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
135
 
136
  # Define Gradio interface
137
- import gradio as gr
138
  iface = gr.Interface(
139
  fn=analyze_video,
140
  inputs=gr.Video(label="Upload Video"),
@@ -145,4 +135,4 @@ iface = gr.Interface(
145
 
146
  # Launch the Gradio app
147
  if __name__ == "__main__":
148
- iface.launch()
 
4
  import torch
5
  from ultralytics import YOLO
6
  from sort import Sort
7
+ import gradio as gr
8
 
9
  # Load YOLOv12x model
10
  MODEL_PATH = "yolov12x.pt"
 
14
  TRUCK_CLASS_ID = 7 # "truck"
15
 
16
  # Initialize SORT tracker
17
+ tracker = Sort(max_age=20, min_hits=3, iou_threshold=0.3) # Improved tracking stability
18
 
19
  # Minimum confidence threshold for detection
20
+ CONFIDENCE_THRESHOLD = 0.4 # Adjusted to capture more trucks
21
 
22
  # Distance threshold to avoid duplicate counts
23
  DISTANCE_THRESHOLD = 50
24
 
25
  # Dictionary to define keyword-based time intervals
26
  TIME_INTERVALS = {
27
+ "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
28
+ "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11
 
 
 
 
 
 
 
 
 
29
  }
30
 
 
31
  def determine_time_interval(video_filename):
32
+ """ Determines frame skip interval based on keywords in the filename. """
33
  for keyword, interval in TIME_INTERVALS.items():
34
  if keyword in video_filename:
 
35
  return interval
36
+ return 5 # Default interval
 
 
37
 
38
  def count_unique_trucks(video_path):
39
+ """ Counts unique trucks in a video using YOLOv12x and SORT tracking. """
40
  cap = cv2.VideoCapture(video_path)
41
  if not cap.isOpened():
42
+ return {"Error": "Unable to open video file."}
43
 
44
  unique_truck_ids = set()
45
  truck_history = {}
 
51
  video_filename = os.path.basename(video_path).lower()
52
 
53
  # Determine the dynamic time interval based on filename keywords
54
+ time_interval = determine_time_interval(video_filename)
55
+
56
  # Get total frames in the video
57
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
58
 
59
+ # Dynamically adjust frame skipping based on FPS and movement density
60
+ frame_skip = max(1, min(fps * time_interval // 2, total_frames // 10))
 
 
61
 
62
  frame_count = 0
63
 
 
68
 
69
  frame_count += 1
70
  if frame_count % frame_skip != 0:
71
+ continue # Skip frames based on interval
72
 
73
  # Run YOLOv12x inference
74
  results = model(frame, verbose=False)
 
84
  x1, y1, x2, y2 = map(int, box.xyxy[0]) # Get bounding box
85
  detections.append([x1, y1, x2, y2, confidence])
86
 
87
+ # Convert detections to numpy array for SORT
88
+ detections = np.array(detections) if len(detections) > 0 else np.empty((0, 5))
 
89
 
90
+ # Update SORT tracker
91
+ tracked_objects = tracker.update(detections)
 
92
 
93
+ # Track movement history to avoid duplicate counts
94
+ for obj in tracked_objects:
95
+ truck_id = int(obj[4]) # Unique ID assigned by SORT
96
+ x1, y1, x2, y2 = obj[:4] # Get bounding box coordinates
97
 
98
+ truck_center = (x1 + x2) / 2, (y1 + y2) / 2 # Calculate truck center
 
 
 
99
 
100
+ # Entry-exit zone logic (e.g., bottom 20% of the frame)
101
+ frame_height, frame_width = frame.shape[:2]
102
+ entry_line = frame_height * 0.8 # Bottom 20% of the frame
103
+ exit_line = frame_height * 0.2 # Top 20% of the frame
104
 
105
+ if truck_id not in truck_history:
106
+ # New truck detected
107
+ truck_history[truck_id] = {
108
+ "position": truck_center,
109
+ "crossed_entry": truck_center[1] > entry_line,
110
+ "crossed_exit": False
111
+ }
112
+ continue
113
 
114
+ # If the truck crosses from entry to exit, count it
115
+ if truck_history[truck_id]["crossed_entry"] and truck_center[1] < exit_line:
116
+ truck_history[truck_id]["crossed_exit"] = True
117
+ unique_truck_ids.add(truck_id)
118
 
119
+ cap.release()
120
  return {"Total Unique Trucks": len(unique_truck_ids)}
121
 
122
  # Gradio UI function
 
125
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
126
 
127
  # Define Gradio interface
 
128
  iface = gr.Interface(
129
  fn=analyze_video,
130
  inputs=gr.Video(label="Upload Video"),
 
135
 
136
  # Launch the Gradio app
137
  if __name__ == "__main__":
138
+ iface.launch()