hb-setosys commited on
Commit
55b2656
·
verified ·
1 Parent(s): 1bf8802

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -41
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import torch
5
  from ultralytics import YOLO
6
  from sort import Sort
 
7
 
8
  # Load YOLOv12x model
9
  MODEL_PATH = "yolov12x.pt"
@@ -16,41 +17,32 @@ TRUCK_CLASS_ID = 7 # "truck"
16
  tracker = Sort()
17
 
18
  # Minimum confidence threshold for detection
19
- CONFIDENCE_THRESHOLD = 0.5
20
 
21
  # Distance threshold to avoid duplicate counts
22
  DISTANCE_THRESHOLD = 50
23
 
24
  # Dictionary to define keyword-based time intervals
25
  TIME_INTERVALS = {
26
- "one": 1,
27
- "two": 2,
28
- "three": 3,
29
- "four": 4,
30
- "five": 5,
31
- "six": 6,
32
- "seven": 7,
33
- "eight": 8,
34
- "nine": 9,
35
- "ten": 10,
36
- "eleven": 11
37
  }
38
 
39
-
40
  def determine_time_interval(video_filename):
 
41
  print(f"Checking filename: {video_filename}") # Debugging
42
  for keyword, interval in TIME_INTERVALS.items():
43
  if keyword in video_filename:
44
  print(f"Matched keyword: {keyword} -> Interval: {interval}") # Debugging
45
  return interval
46
  print("No keyword match, using default interval: 5") # Debugging
47
- return 5 # Default interval if no keyword matches
48
-
49
 
50
  def count_unique_trucks(video_path):
 
51
  cap = cv2.VideoCapture(video_path)
52
  if not cap.isOpened():
53
- return "Error: Unable to open video file."
54
 
55
  unique_truck_ids = set()
56
  truck_history = {}
@@ -63,13 +55,12 @@ def count_unique_trucks(video_path):
63
 
64
  # Determine the dynamic time interval based on filename keywords
65
  time_interval = determine_time_interval(video_filename)
 
66
  # Get total frames in the video
67
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
68
 
69
  # Ensure frame_skip does not exceed total frames
70
- frame_skip = min(fps * time_interval, total_frames)
71
-
72
- #frame_skip = fps * time_interval # Convert time interval to frame count
73
 
74
  frame_count = 0
75
 
@@ -80,7 +71,7 @@ def count_unique_trucks(video_path):
80
 
81
  frame_count += 1
82
  if frame_count % frame_skip != 0:
83
- continue # Skip frames to process only every 5 seconds
84
 
85
  # Run YOLOv12x inference
86
  results = model(frame, verbose=False)
@@ -96,35 +87,41 @@ def count_unique_trucks(video_path):
96
  x1, y1, x2, y2 = map(int, box.xyxy[0]) # Get bounding box
97
  detections.append([x1, y1, x2, y2, confidence])
98
 
 
 
 
99
  if len(detections) > 0:
100
  detections = np.array(detections)
101
  tracked_objects = tracker.update(detections)
 
 
102
 
103
- for obj in tracked_objects:
104
- truck_id = int(obj[4]) # Unique ID assigned by SORT
105
- x1, y1, x2, y2 = obj[:4] # Get the bounding box coordinates
106
 
107
- truck_center = (x1 + x2) / 2, (y1 + y2) / 2 # Calculate the center of the truck
 
 
108
 
109
- # If truck is already in history, check the movement distance
110
- if truck_id in truck_history:
111
- last_position = truck_history[truck_id]["position"]
112
- distance = np.linalg.norm(np.array(truck_center) - np.array(last_position))
113
 
114
- if distance > DISTANCE_THRESHOLD:
115
- # If the truck moved significantly, count as new
116
- unique_truck_ids.add(truck_id)
 
117
 
118
- else:
119
- # If truck is not in history, add it
120
- truck_history[truck_id] = {
121
- "frame_count": frame_count,
122
- "position": truck_center
123
- }
124
- unique_truck_ids.add(truck_id)
125
 
126
- cap.release()
 
 
 
 
 
 
127
 
 
128
  return {"Total Unique Trucks": len(unique_truck_ids)}
129
 
130
  # Gradio UI function
@@ -133,7 +130,6 @@ def analyze_video(video_file):
133
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
134
 
135
  # Define Gradio interface
136
- import gradio as gr
137
  iface = gr.Interface(
138
  fn=analyze_video,
139
  inputs=gr.Video(label="Upload Video"),
@@ -145,4 +141,3 @@ iface = gr.Interface(
145
  # Launch the Gradio app
146
  if __name__ == "__main__":
147
  iface.launch()
148
-
 
4
  import torch
5
  from ultralytics import YOLO
6
  from sort import Sort
7
+ import gradio as gr
8
 
9
  # Load YOLOv12x model
10
  MODEL_PATH = "yolov12x.pt"
 
17
  tracker = Sort()
18
 
19
  # Minimum confidence threshold for detection
20
+ CONFIDENCE_THRESHOLD = 0.4 # Lowered for better detection
21
 
22
  # Distance threshold to avoid duplicate counts
23
  DISTANCE_THRESHOLD = 50
24
 
25
  # Dictionary to define keyword-based time intervals
26
  TIME_INTERVALS = {
27
+ "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
28
+ "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11
 
 
 
 
 
 
 
 
 
29
  }
30
 
 
31
  def determine_time_interval(video_filename):
32
+ """ Determines frame skip interval based on keywords in the filename. """
33
  print(f"Checking filename: {video_filename}") # Debugging
34
  for keyword, interval in TIME_INTERVALS.items():
35
  if keyword in video_filename:
36
  print(f"Matched keyword: {keyword} -> Interval: {interval}") # Debugging
37
  return interval
38
  print("No keyword match, using default interval: 5") # Debugging
39
+ return 5 # Default interval
 
40
 
41
  def count_unique_trucks(video_path):
42
+ """ Counts unique trucks in a video using YOLOv12x and SORT tracking. """
43
  cap = cv2.VideoCapture(video_path)
44
  if not cap.isOpened():
45
+ return {"Error": "Unable to open video file."}
46
 
47
  unique_truck_ids = set()
48
  truck_history = {}
 
55
 
56
  # Determine the dynamic time interval based on filename keywords
57
  time_interval = determine_time_interval(video_filename)
58
+
59
  # Get total frames in the video
60
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
61
 
62
  # Ensure frame_skip does not exceed total frames
63
+ frame_skip = min(fps * time_interval, total_frames // 2) # Reduced skipping
 
 
64
 
65
  frame_count = 0
66
 
 
71
 
72
  frame_count += 1
73
  if frame_count % frame_skip != 0:
74
+ continue # Skip frames based on interval
75
 
76
  # Run YOLOv12x inference
77
  results = model(frame, verbose=False)
 
87
  x1, y1, x2, y2 = map(int, box.xyxy[0]) # Get bounding box
88
  detections.append([x1, y1, x2, y2, confidence])
89
 
90
+ # Debugging: Check detections
91
+ print(f"Frame {frame_count}: Detections -> {detections}")
92
+
93
  if len(detections) > 0:
94
  detections = np.array(detections)
95
  tracked_objects = tracker.update(detections)
96
+ else:
97
+ tracked_objects = [] # Prevent tracker from resetting
98
 
99
+ # Debugging: Check tracked objects
100
+ print(f"Frame {frame_count}: Tracked Objects -> {tracked_objects}")
 
101
 
102
+ for obj in tracked_objects:
103
+ truck_id = int(obj[4]) # Unique ID assigned by SORT
104
+ x1, y1, x2, y2 = obj[:4] # Get the bounding box coordinates
105
 
106
+ truck_center = (x1 + x2) / 2, (y1 + y2) / 2 # Calculate truck center
 
 
 
107
 
108
+ # If truck is already in history, check movement distance
109
+ if truck_id in truck_history:
110
+ last_position = truck_history[truck_id]["position"]
111
+ distance = np.linalg.norm(np.array(truck_center) - np.array(last_position))
112
 
113
+ if distance > DISTANCE_THRESHOLD:
114
+ unique_truck_ids.add(truck_id) # Add only if moved significantly
 
 
 
 
 
115
 
116
+ else:
117
+ # If truck is not in history, add it
118
+ truck_history[truck_id] = {
119
+ "frame_count": frame_count,
120
+ "position": truck_center
121
+ }
122
+ unique_truck_ids.add(truck_id)
123
 
124
+ cap.release()
125
  return {"Total Unique Trucks": len(unique_truck_ids)}
126
 
127
  # Gradio UI function
 
130
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
131
 
132
  # Define Gradio interface
 
133
  iface = gr.Interface(
134
  fn=analyze_video,
135
  inputs=gr.Video(label="Upload Video"),
 
141
  # Launch the Gradio app
142
  if __name__ == "__main__":
143
  iface.launch()