AjaykumarPilla commited on
Commit
e71d142
·
verified ·
1 Parent(s): 657bd9e

Update gully_drs_core/ball_detection.py

Browse files
Files changed (1) hide show
  1. gully_drs_core/ball_detection.py +30 -85
gully_drs_core/ball_detection.py CHANGED
@@ -1,91 +1,36 @@
 
1
  import cv2
2
  import numpy as np
3
- import torch
4
- import logging
5
- import os
6
- from gully_drs_core.replay_utils import generate_replay
7
- from gully_drs_core.video_utils import get_video_properties
8
- from gully_drs_core.model_utils import load_yolo_model
9
 
10
- # Set up logging
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
 
 
 
13
 
14
- # Load YOLOv5 model
15
- model = load_yolo_model()
16
 
17
- # Stump zone coordinates (example, adjust based on video resolution)
18
- STUMP_ZONE = [(200, 400), (300, 400), (300, 600), (200, 600)] # [x1,y1, x2,y2, x3,y3, x4,y4]
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def process_video(video_path):
21
- try:
22
- # Validate video file
23
- if not os.path.exists(video_path):
24
- return {"status": "error", "error": "Video file not found"}
25
-
26
- # Get video properties
27
- fps, width, height = get_video_properties(video_path)
28
- if fps == 0:
29
- return {"status": "error", "error": "Invalid video file"}
30
-
31
- # Initialize video capture
32
- cap = cv2.VideoCapture(video_path)
33
- ball_positions = []
34
- bounce_point = None
35
- frame_count = 0
36
-
37
- while cap.isOpened():
38
- ret, frame = cap.read()
39
- if not ret:
40
- break
41
-
42
- # Detect ball using YOLOv5
43
- results = model(frame)
44
- detections = results.xyxy[0].cpu().numpy() # [x1, y1, x2, y2, conf, class]
45
-
46
- ball_center = None
47
- for det in detections:
48
- if det[5] == 0: # Assuming class 0 is the ball
49
- x1, y1, x2, y2 = map(int, det[:4])
50
- ball_center = ((x1 + x2) // 2, (y1 + y2) // 2)
51
- ball_positions.append(ball_center)
52
- # Detect bounce point (simplified: assume bounce near ground)
53
- if y1 > height * 0.8 and bounce_point is None:
54
- bounce_point = ball_center
55
- break
56
-
57
- frame_count += 1
58
-
59
- cap.release()
60
-
61
- # Check LBW decision
62
- decision = "Not Out"
63
- for pos in ball_positions:
64
- if is_ball_in_stump_zone(pos):
65
- decision = "Out"
66
- break
67
-
68
- # Calculate speed (pixels per frame to km/h)
69
- speed_kmh = 0
70
- if len(ball_positions) >= 2:
71
- pixel_dist = np.sqrt((ball_positions[-1][0] - ball_positions[-2][0])**2 +
72
- (ball_positions[-1][1] - ball_positions[-2][1])**2)
73
- speed_kmh = (pixel_dist / (1/fps)) * 0.036 # Simplified conversion, adjust scale factor
74
-
75
- # Generate replay video
76
- replay_path = generate_replay(video_path, ball_positions, STUMP_ZONE, decision, speed_kmh, bounce_point)
77
-
78
- return {
79
- "status": "success",
80
- "decision": decision,
81
- "speed_kmh": speed_kmh,
82
- "replay_path": replay_path
83
- }
84
- except Exception as e:
85
- logger.error(f"Error processing video: {str(e)}")
86
- return {"status": "error", "error": str(e)}
87
-
88
- def is_ball_in_stump_zone(ball_center):
89
- x, y = ball_center
90
- x1, y1, x2, y2, x3, y3, x4, y4 = STUMP_ZONE[0] + STUMP_ZONE[1] + STUMP_ZONE[2] + STUMP_ZONE[3]
91
- return (x1 <= x <= x2) and (y1 <= y <= y3)
 
1
+ # gully_drs_core/ball_detection.py
2
  import cv2
3
  import numpy as np
4
+ from .model_utils import load_model
 
 
 
 
 
5
 
6
+ def analyze_video(file_path):
7
+ model = load_model()
8
+ cap = cv2.VideoCapture(file_path)
9
+ fps = cap.get(cv2.CAP_PROP_FPS)
10
+ width = int(cap.get(3))
11
+ height = int(cap.get(4))
12
 
13
+ ball_path = []
14
+ frames = []
15
 
16
+ while True:
17
+ ret, frame = cap.read()
18
+ if not ret:
19
+ break
20
+ results = model(frame)
21
+ for r in results:
22
+ for box in r.boxes:
23
+ cls = int(box.cls[0])
24
+ if cls == 32: # Assuming class 32 = cricket ball
25
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
26
+ cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
27
+ ball_path.append((cx, cy))
28
+ cv2.circle(frame, (cx, cy), 6, (0, 255, 0), -1)
29
+ frames.append(frame)
30
 
31
+ cap.release()
32
+ return {
33
+ "trajectory": ball_path,
34
+ "fps": fps,
35
+ "frames": frames
36
+ }