AjaykumarPilla committed
Commit c0896c8 · verified · 1 Parent(s): e367638

Update app.py

Files changed (1)
  1. app.py +57 -28
app.py CHANGED
@@ -6,6 +6,7 @@ import plotly.graph_objects as go
 import torch
 import gradio as gr
 import os
+import time
 from scipy.optimize import curve_fit
 import sys
 
@@ -24,6 +25,9 @@ STUMP_WIDTH = 0.2286 # Stump width (including bails)
 
 # Model input size (adjust if best.pt was trained with a different size)
 MODEL_INPUT_SIZE = (640, 640) # (height, width)
+FRAME_SKIP = 2 # Process every 2nd frame
+MIN_DETECTIONS = 10 # Stop after 10 detections
+BATCH_SIZE = 4 # Process 4 frames at a time
 
 # Load model
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -40,45 +44,70 @@ def process_video(video_path):
     frame_numbers = []
     bounce_frame = None
     bounce_point = None
+    batch_frames = []
+    batch_frame_nums = []
+    frame_count = 0
 
+    start_time = time.time()
     while cap.isOpened():
         frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
         ret, frame = cap.read()
         if not ret:
             break
 
+        # Skip frames
+        if frame_count % FRAME_SKIP != 0:
+            frame_count += 1
+            continue
+
         # Resize frame to model input size
         frame = cv2.resize(frame, MODEL_INPUT_SIZE, interpolation=cv2.INTER_AREA)
-
-        # Preprocess frame for YOLOv5
-        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        img = torch.from_numpy(img).to(device).float() / 255.0
-        img = img.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
-
-        # Run inference
-        with torch.no_grad():
-            pred = model(img)[0]
-            pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
-
-        # Process detections
-        for det in pred:
-            if det is not None and len(det):
-                det = xywh2xyxy(det) # Convert to [x1, y1, x2, y2]
-                for *xyxy, conf, cls in det:
-                    x_center = (xyxy[0] + xyxy[2]) / 2
-                    y_center = (xyxy[1] + xyxy[3]) / 2
-                    # Scale coordinates back to original frame size
-                    x_center = x_center * frame_width / MODEL_INPUT_SIZE[1]
-                    y_center = y_center * frame_height / MODEL_INPUT_SIZE[0]
-                    positions.append((x_center.item(), y_center.item()))
-                    frame_numbers.append(frame_num)
-
-                    # Detect bounce (lowest y_center point)
-                    if bounce_frame is None or y_center > positions[bounce_frame][1]:
-                        bounce_frame = len(frame_numbers) - 1
-                        bounce_point = (x_center.item(), y_center.item())
+        batch_frames.append(frame)
+        batch_frame_nums.append(frame_num)
+        frame_count += 1
+
+        # Process batch when full or at end
+        if len(batch_frames) == BATCH_SIZE or not ret:
+            # Preprocess batch
+            batch = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in batch_frames]
+            batch = np.stack(batch) # [batch_size, H, W, 3]
+            batch = torch.from_numpy(batch).to(device).float() / 255.0
+            batch = batch.permute(0, 3, 1, 2) # [batch_size, 3, H, W]
+
+            # Run inference
+            frame_start_time = time.time()
+            with torch.no_grad():
+                pred = model(batch)[0]
+                pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
+            print(f"Batch inference time: {time.time() - frame_start_time:.2f}s for {len(batch_frames)} frames")
+
+            # Process detections
+            for i, det in enumerate(pred):
+                if det is not None and len(det):
+                    det = xywh2xyxy(det) # Convert to [x1, y1, x2, y2]
+                    for *xyxy, conf, cls in det:
+                        x_center = (xyxy[0] + xyxy[2]) / 2
+                        y_center = (xyxy[1] + xyxy[3]) / 2
+                        # Scale coordinates back to original frame size
+                        x_center = x_center * frame_width / MODEL_INPUT_SIZE[1]
+                        y_center = y_center * frame_height / MODEL_INPUT_SIZE[0]
+                        positions.append((x_center.item(), y_center.item()))
+                        frame_numbers.append(batch_frame_nums[i])
+
+                        # Detect bounce (lowest y_center point)
+                        if bounce_frame is None or y_center > positions[bounce_frame][1]:
+                            bounce_frame = len(frame_numbers) - 1
+                            bounce_point = (x_center.item(), y_center.item())
+
+            batch_frames = []
+            batch_frame_nums = []
+
+            # Early termination
+            if len(positions) >= MIN_DETECTIONS:
+                break
 
     cap.release()
+    print(f"Total video processing time: {time.time() - start_time:.2f}s")
    return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height
 
 # Polynomial function for trajectory fitting
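
A small flag on the unchanged resize line: the comment on MODEL_INPUT_SIZE says (height, width), but cv2.resize takes its dsize argument as (width, height). With the square 640×640 input the two orderings coincide; for a non-square size the call would transpose the target shape. A minimal sketch of the explicit ordering, assuming the (height, width) convention is kept:

# cv2.resize expects dsize=(width, height); MODEL_INPUT_SIZE is (height, width)
frame = cv2.resize(frame, (MODEL_INPUT_SIZE[1], MODEL_INPUT_SIZE[0]),
                   interpolation=cv2.INTER_AREA)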
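
A note on the batching logic: inside the loop, `ret` is always True by the time the batch check runs, because a failed read breaks out of the loop first. The `or not ret` clause is therefore dead, and a trailing batch of fewer than BATCH_SIZE frames is dropped (the MIN_DETECTIONS early exit hides this on longer clips). The batch path also calls `np.stack`, so `numpy` is presumably imported as `np` elsewhere in app.py, since this diff only adds `import time`. A minimal sketch of a read loop that flushes the leftover frames; the `iter_batches` generator is illustrative, not part of the commit:

import cv2

def iter_batches(video_path, batch_size=4, frame_skip=2, size=(640, 640)):
    # Yield (frames, frame_numbers) batches, flushing the final
    # partial batch instead of dropping it.
    cap = cv2.VideoCapture(video_path)
    frames, nums, count = [], [], 0
    while cap.isOpened():
        num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
        ret, frame = cap.read()
        if not ret:
            break
        if count % frame_skip == 0:
            frames.append(cv2.resize(frame, size, interpolation=cv2.INTER_AREA))
            nums.append(num)
        count += 1
        if len(frames) == batch_size:
            yield frames, nums
            frames, nums = [], []
    cap.release()
    if frames:  # leftover frames smaller than a full batch
        yield frames, nums

With that shape, the preprocess/inference/detection body runs once per yielded batch and the final flush comes for free.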
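
One more hedged observation: if `xywh2xyxy` here is the stock YOLOv5 helper, the tensors returned by `non_max_suppression` are already `[x1, y1, x2, y2, conf, cls]`, so the `xywh2xyxy(det)` call carried over from the old code would re-convert corner coordinates as though they were center/width boxes. A sketch of the post-NMS loop under that assumption, reusing the diff's variable names:

# Assumes pred is yolov5's non_max_suppression() output: one tensor of
# [x1, y1, x2, y2, conf, cls] rows per image in the batch.
for i, det in enumerate(pred):
    if det is None or not len(det):
        continue
    for *xyxy, conf, cls in det:  # already corner format, no xywh2xyxy needed
        # Box midpoint, rescaled from model input size to original frame size
        x_center = (xyxy[0] + xyxy[2]) / 2 * frame_width / MODEL_INPUT_SIZE[1]
        y_center = (xyxy[1] + xyxy[3]) / 2 * frame_height / MODEL_INPUT_SIZE[0]
        positions.append((x_center.item(), y_center.item()))
        frame_numbers.append(batch_frame_nums[i])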