hb-setosys committed
Commit 9b6b6b1 · verified · Parent: e5e492a

Update app.py

Files changed (1):
  1. app.py (+49 -44)
app.py CHANGED
@@ -2,28 +2,22 @@ import os
 import cv2
 import numpy as np
 import torch
-import logging
 from ultralytics import YOLO
 from sort import Sort
 import gradio as gr
 
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-
 # Load YOLOv12x model
 MODEL_PATH = "yolov12x.pt"
-if not os.path.exists(MODEL_PATH):
-    raise FileNotFoundError(f"Model file '{MODEL_PATH}' not found.")
 model = YOLO(MODEL_PATH)
 
 # COCO dataset class ID for truck
 TRUCK_CLASS_ID = 7  # "truck"
 
 # Initialize SORT tracker
-tracker = Sort()
+tracker = Sort(max_age=20, min_hits=3, iou_threshold=0.3)  # Improved tracking stability
 
 # Minimum confidence threshold for detection
-CONFIDENCE_THRESHOLD = 0.4  # Adjust based on performance
+CONFIDENCE_THRESHOLD = 0.4  # Adjusted to capture more trucks
 
 # Distance threshold to avoid duplicate counts
 DISTANCE_THRESHOLD = 50
@@ -36,39 +30,38 @@ TIME_INTERVALS = {
 
 def determine_time_interval(video_filename):
     """ Determines frame skip interval based on keywords in the filename. """
-    logging.info(f"Checking filename: {video_filename}")
     for keyword, interval in TIME_INTERVALS.items():
         if keyword in video_filename:
-            logging.info(f"Matched keyword: {keyword} -> Interval: {interval}")
             return interval
-    logging.info("No keyword match, using default interval: 5")
     return 5  # Default interval
 
 def count_unique_trucks(video_path):
     """ Counts unique trucks in a video using YOLOv12x and SORT tracking. """
-    if not os.path.exists(video_path):
-        return {"Error": "Video file not found."}
-
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         return {"Error": "Unable to open video file."}
 
     unique_truck_ids = set()
     truck_history = {}
-
-    # Get FPS and total frames
-    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # Default to 30 if retrieval fails
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 1
-
-    # Extract filename and determine time interval
+
+    # Get FPS of the video
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+    # Extract filename from the path and convert to lowercase
     video_filename = os.path.basename(video_path).lower()
-    time_interval = determine_time_interval(video_filename)
 
-    # Ensure frame_skip does not exceed total frames
-    frame_skip = min(fps * time_interval, max(1, total_frames // 2))
+    # Determine the dynamic time interval based on filename keywords
+    #time_interval = determine_time_interval(video_filename)
+    time_interval = 7
+    # Get total frames in the video
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Dynamically adjust frame skipping based on FPS and movement density
+    frame_skip = max(1, min(fps * time_interval // 2, total_frames // 10))
+
     frame_count = 0
 
-    while cap.isOpened():
+    while True:
         ret, frame = cap.read()
         if not ret:
             break  # End of video
@@ -83,29 +76,44 @@ def count_unique_trucks(video_path):
         detections = []
         for result in results:
             for box in result.boxes:
-                class_id = int(box.cls.item())
-                confidence = float(box.conf.item())
+                class_id = int(box.cls.item())  # Get class ID
+                confidence = float(box.conf.item())  # Get confidence score
 
+                # Track only trucks
                 if class_id == TRUCK_CLASS_ID and confidence > CONFIDENCE_THRESHOLD:
-                    x1, y1, x2, y2 = map(int, box.xyxy[0])
+                    x1, y1, x2, y2 = map(int, box.xyxy[0])  # Get bounding box
                     detections.append([x1, y1, x2, y2, confidence])
 
-        if detections:
-            tracked_objects = tracker.update(np.array(detections))
-        else:
-            tracked_objects = []
+        # Convert detections to numpy array for SORT
+        detections = np.array(detections) if len(detections) > 0 else np.empty((0, 5))
+
+        # Update SORT tracker
+        tracked_objects = tracker.update(detections)
 
+        # Track movement history to avoid duplicate counts
         for obj in tracked_objects:
-            truck_id = int(obj[4])
-            truck_center = ((obj[0] + obj[2]) / 2, (obj[1] + obj[3]) / 2)
-
-            if truck_id in truck_history:
-                last_position = truck_history[truck_id]["position"]
-                distance = np.linalg.norm(np.array(truck_center) - np.array(last_position))
-                if distance > DISTANCE_THRESHOLD:
-                    unique_truck_ids.add(truck_id)
-            else:
-                truck_history[truck_id] = {"position": truck_center}
+            truck_id = int(obj[4])  # Unique ID assigned by SORT
+            x1, y1, x2, y2 = obj[:4]  # Get bounding box coordinates
+
+            truck_center = (x1 + x2) / 2, (y1 + y2) / 2  # Calculate truck center
+
+            # Entry-exit zone logic (e.g., bottom 20% of the frame)
+            frame_height, frame_width = frame.shape[:2]
+            entry_line = frame_height * 0.8  # Bottom 20% of the frame
+            exit_line = frame_height * 0.2   # Top 20% of the frame
+
+            if truck_id not in truck_history:
+                # New truck detected
+                truck_history[truck_id] = {
+                    "position": truck_center,
+                    "crossed_entry": truck_center[1] > entry_line,
+                    "crossed_exit": False
+                }
+                continue
+
+            # If the truck crosses from entry to exit, count it
+            if truck_history[truck_id]["crossed_entry"] and truck_center[1] < exit_line:
+                truck_history[truck_id]["crossed_exit"] = True
                 unique_truck_ids.add(truck_id)
 
     cap.release()
@@ -113,9 +121,6 @@ def count_unique_trucks(video_path):
 
 # Gradio UI function
 def analyze_video(video_file):
-    if not video_file:
-        return "Error: No video file uploaded."
-
    result = count_unique_trucks(video_file)
    return "\n".join([f"{key}: {value}" for key, value in result.items()])
 
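
A note on the new frame-skipping rule: frame_skip = max(1, min(fps * time_interval // 2, total_frames // 10)) caps the skip at a tenth of the video and never lets it drop below one frame. A rough worked example, with values assumed for illustration rather than taken from the commit:

# Assumed example values; not from the commit.
fps, time_interval, total_frames = 30, 7, 9000   # e.g. a 5-minute clip at 30 fps
frame_skip = max(1, min(fps * time_interval // 2, total_frames // 10))
print(frame_skip)  # 105 -> YOLO runs on roughly one frame every 3.5 seconds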
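
The core behavioral change in this commit is the counting rule: instead of comparing a track's displacement against DISTANCE_THRESHOLD, a truck is now counted only when its SORT track first appears in the bottom entry zone (below 80% of the frame height) and later crosses the top exit line (20% of the frame height). The sketch below replays that rule on hand-made track centers so it can be checked without YOLO, SORT, or a video; count_entry_exit and demo_tracks are illustrative and do not exist in app.py.

# Minimal sketch of the entry/exit counting rule, run on synthetic tracks.
def count_entry_exit(tracks, frame_height=1000):
    """tracks: {track_id: [(cx, cy), ...]} per-frame centers from a tracker."""
    entry_line = frame_height * 0.8   # bottom 20% of the frame
    exit_line = frame_height * 0.2    # top 20% of the frame
    counted = set()
    for track_id, centers in tracks.items():
        # Same idea as app.py: the first observation decides whether the track
        # entered through the bottom zone; a later center above the exit line counts it.
        crossed_entry = centers[0][1] > entry_line
        if crossed_entry and any(cy < exit_line for _, cy in centers[1:]):
            counted.add(track_id)
    return len(counted)

# Truck 1 drives from the bottom zone past the exit line (counted);
# truck 2 stays in the middle of the frame (not counted).
demo_tracks = {
    1: [(320, 900), (330, 620), (335, 400), (340, 150)],
    2: [(610, 550), (605, 530), (600, 510)],
}
print(count_entry_exit(demo_tracks))  # -> 1

One consequence of this rule worth noting: a track whose first detection is already mid-frame never sets crossed_entry, so it is never counted, which appears to be the intended filter against partial or late-appearing tracks.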
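
The diff ends inside analyze_video, so the Gradio wiring itself is not visible in this commit. For orientation only, a typical way to expose analyze_video as a Space UI looks like the sketch below; the component choices, labels, and the stub body are assumptions, not code from app.py.

import gradio as gr

def analyze_video(video_file):
    # Stand-in for the real analyze_video in app.py, which calls
    # count_unique_trucks(video_file) and formats the result.
    return "Total Unique Trucks: 0"

demo = gr.Interface(
    fn=analyze_video,
    inputs=gr.Video(label="Traffic video"),
    outputs=gr.Textbox(label="Truck count"),
    title="Unique Truck Counter (YOLOv12x + SORT)",
)

if __name__ == "__main__":
    demo.launch()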