hb-setosys commited on
Commit
e5e492a
·
verified ·
1 Parent(s): 825941a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -44
app.py CHANGED
@@ -2,12 +2,18 @@ import os
2
  import cv2
3
  import numpy as np
4
  import torch
 
5
  from ultralytics import YOLO
6
  from sort import Sort
7
  import gradio as gr
8
 
 
 
 
9
  # Load YOLOv12x model
10
  MODEL_PATH = "yolov12x.pt"
 
 
11
  model = YOLO(MODEL_PATH)
12
 
13
  # COCO dataset class ID for truck
@@ -17,7 +23,7 @@ TRUCK_CLASS_ID = 7 # "truck"
17
  tracker = Sort()
18
 
19
  # Minimum confidence threshold for detection
20
- CONFIDENCE_THRESHOLD = 0.4 # Lowered for better detection
21
 
22
  # Distance threshold to avoid duplicate counts
23
  DISTANCE_THRESHOLD = 50
@@ -30,41 +36,39 @@ TIME_INTERVALS = {
30
 
31
  def determine_time_interval(video_filename):
32
  """ Determines frame skip interval based on keywords in the filename. """
33
- print(f"Checking filename: {video_filename}") # Debugging
34
  for keyword, interval in TIME_INTERVALS.items():
35
  if keyword in video_filename:
36
- print(f"Matched keyword: {keyword} -> Interval: {interval}") # Debugging
37
  return interval
38
- print("No keyword match, using default interval: 5") # Debugging
39
  return 5 # Default interval
40
 
41
  def count_unique_trucks(video_path):
42
  """ Counts unique trucks in a video using YOLOv12x and SORT tracking. """
 
 
 
43
  cap = cv2.VideoCapture(video_path)
44
  if not cap.isOpened():
45
  return {"Error": "Unable to open video file."}
46
 
47
  unique_truck_ids = set()
48
  truck_history = {}
49
-
50
- # Get FPS of the video
51
- fps = int(cap.get(cv2.CAP_PROP_FPS))
52
-
53
- # Extract filename from the path and convert to lowercase
 
54
  video_filename = os.path.basename(video_path).lower()
55
-
56
- # Determine the dynamic time interval based on filename keywords
57
  time_interval = determine_time_interval(video_filename)
58
 
59
- # Get total frames in the video
60
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
61
-
62
  # Ensure frame_skip does not exceed total frames
63
- frame_skip = min(fps * time_interval, total_frames // 2) # Reduced skipping
64
-
65
  frame_count = 0
66
 
67
- while True:
68
  ret, frame = cap.read()
69
  if not ret:
70
  break # End of video
@@ -79,46 +83,29 @@ def count_unique_trucks(video_path):
79
  detections = []
80
  for result in results:
81
  for box in result.boxes:
82
- class_id = int(box.cls.item()) # Get class ID
83
- confidence = float(box.conf.item()) # Get confidence score
84
 
85
- # Track only trucks
86
  if class_id == TRUCK_CLASS_ID and confidence > CONFIDENCE_THRESHOLD:
87
- x1, y1, x2, y2 = map(int, box.xyxy[0]) # Get bounding box
88
  detections.append([x1, y1, x2, y2, confidence])
89
 
90
- # Debugging: Check detections
91
- print(f"Frame {frame_count}: Detections -> {detections}")
92
-
93
- if len(detections) > 0:
94
- detections = np.array(detections)
95
- tracked_objects = tracker.update(detections)
96
  else:
97
- tracked_objects = [] # Prevent tracker from resetting
98
-
99
- # Debugging: Check tracked objects
100
- print(f"Frame {frame_count}: Tracked Objects -> {tracked_objects}")
101
 
102
  for obj in tracked_objects:
103
- truck_id = int(obj[4]) # Unique ID assigned by SORT
104
- x1, y1, x2, y2 = obj[:4] # Get the bounding box coordinates
105
-
106
- truck_center = (x1 + x2) / 2, (y1 + y2) / 2 # Calculate truck center
107
 
108
- # If truck is already in history, check movement distance
109
  if truck_id in truck_history:
110
  last_position = truck_history[truck_id]["position"]
111
  distance = np.linalg.norm(np.array(truck_center) - np.array(last_position))
112
-
113
  if distance > DISTANCE_THRESHOLD:
114
- unique_truck_ids.add(truck_id) # Add only if moved significantly
115
-
116
  else:
117
- # If truck is not in history, add it
118
- truck_history[truck_id] = {
119
- "frame_count": frame_count,
120
- "position": truck_center
121
- }
122
  unique_truck_ids.add(truck_id)
123
 
124
  cap.release()
@@ -126,6 +113,9 @@ def count_unique_trucks(video_path):
126
 
127
  # Gradio UI function
128
  def analyze_video(video_file):
 
 
 
129
  result = count_unique_trucks(video_file)
130
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
131
 
 
2
  import cv2
3
  import numpy as np
4
  import torch
5
+ import logging
6
  from ultralytics import YOLO
7
  from sort import Sort
8
  import gradio as gr
9
 
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
12
+
13
  # Load YOLOv12x model
14
  MODEL_PATH = "yolov12x.pt"
15
+ if not os.path.exists(MODEL_PATH):
16
+ raise FileNotFoundError(f"Model file '{MODEL_PATH}' not found.")
17
  model = YOLO(MODEL_PATH)
18
 
19
  # COCO dataset class ID for truck
 
23
  tracker = Sort()
24
 
25
  # Minimum confidence threshold for detection
26
+ CONFIDENCE_THRESHOLD = 0.4 # Adjust based on performance
27
 
28
  # Distance threshold to avoid duplicate counts
29
  DISTANCE_THRESHOLD = 50
 
36
 
37
  def determine_time_interval(video_filename):
38
  """ Determines frame skip interval based on keywords in the filename. """
39
+ logging.info(f"Checking filename: {video_filename}")
40
  for keyword, interval in TIME_INTERVALS.items():
41
  if keyword in video_filename:
42
+ logging.info(f"Matched keyword: {keyword} -> Interval: {interval}")
43
  return interval
44
+ logging.info("No keyword match, using default interval: 5")
45
  return 5 # Default interval
46
 
47
  def count_unique_trucks(video_path):
48
  """ Counts unique trucks in a video using YOLOv12x and SORT tracking. """
49
+ if not os.path.exists(video_path):
50
+ return {"Error": "Video file not found."}
51
+
52
  cap = cv2.VideoCapture(video_path)
53
  if not cap.isOpened():
54
  return {"Error": "Unable to open video file."}
55
 
56
  unique_truck_ids = set()
57
  truck_history = {}
58
+
59
+ # Get FPS and total frames
60
+ fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30 # Default to 30 if retrieval fails
61
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 1
62
+
63
+ # Extract filename and determine time interval
64
  video_filename = os.path.basename(video_path).lower()
 
 
65
  time_interval = determine_time_interval(video_filename)
66
 
 
 
 
67
  # Ensure frame_skip does not exceed total frames
68
+ frame_skip = min(fps * time_interval, max(1, total_frames // 2))
 
69
  frame_count = 0
70
 
71
+ while cap.isOpened():
72
  ret, frame = cap.read()
73
  if not ret:
74
  break # End of video
 
83
  detections = []
84
  for result in results:
85
  for box in result.boxes:
86
+ class_id = int(box.cls.item())
87
+ confidence = float(box.conf.item())
88
 
 
89
  if class_id == TRUCK_CLASS_ID and confidence > CONFIDENCE_THRESHOLD:
90
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
91
  detections.append([x1, y1, x2, y2, confidence])
92
 
93
+ if detections:
94
+ tracked_objects = tracker.update(np.array(detections))
 
 
 
 
95
  else:
96
+ tracked_objects = []
 
 
 
97
 
98
  for obj in tracked_objects:
99
+ truck_id = int(obj[4])
100
+ truck_center = ((obj[0] + obj[2]) / 2, (obj[1] + obj[3]) / 2)
 
 
101
 
 
102
  if truck_id in truck_history:
103
  last_position = truck_history[truck_id]["position"]
104
  distance = np.linalg.norm(np.array(truck_center) - np.array(last_position))
 
105
  if distance > DISTANCE_THRESHOLD:
106
+ unique_truck_ids.add(truck_id)
 
107
  else:
108
+ truck_history[truck_id] = {"position": truck_center}
 
 
 
 
109
  unique_truck_ids.add(truck_id)
110
 
111
  cap.release()
 
113
 
114
  # Gradio UI function
115
  def analyze_video(video_file):
116
+ if not video_file:
117
+ return "Error: No video file uploaded."
118
+
119
  result = count_unique_trucks(video_file)
120
  return "\n".join([f"{key}: {value}" for key, value in result.items()])
121