AjaykumarPilla commited on
Commit
41c03cf
·
verified ·
1 Parent(s): a8a412d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from yolov5 import YOLOv5
7
+
8
+ # Load YOLOv5 model (best.pt)
9
+ model = YOLOv5("best.pt") # Adjust the path to your model file
10
+
11
+ # Function to process the video and calculate ball trajectory, speed, and visualize the pitch
12
+ def process_video(video_file):
13
+ # Load video file using OpenCV
14
+ video = cv2.VideoCapture(video_file.name)
15
+ ball_positions = []
16
+ speed_data = []
17
+
18
+ frame_count = 0
19
+ last_position = None
20
+
21
+ while video.isOpened():
22
+ ret, frame = video.read()
23
+ if not ret:
24
+ break
25
+
26
+ frame_count += 1
27
+
28
+ # Run YOLOv5 model on the frame to detect ball
29
+ results = model(frame)
30
+
31
+ # Extract the ball position (assuming class 0 = ball)
32
+ ball_detections = results.pandas().xywh
33
+ ball = ball_detections[ball_detections['class'] == 0] # class 0 is ball, adjust as needed
34
+
35
+ if not ball.empty:
36
+ ball_x = ball.iloc[0]['xmin'] + (ball.iloc[0]['xmax'] - ball.iloc[0]['xmin']) / 2
37
+ ball_y = ball.iloc[0]['ymin'] + (ball.iloc[0]['ymax'] - ball.iloc[0]['ymin']) / 2
38
+ ball_positions.append((frame_count, ball_x, ball_y)) # Track position in each frame
39
+
40
+ if last_position is not None:
41
+ # Calculate speed based on pixel displacement between frames
42
+ distance = np.sqrt((ball_x - last_position[1]) ** 2 + (ball_y - last_position[2]) ** 2)
43
+ fps = video.get(cv2.CAP_PROP_FPS) # Frames per second of the video
44
+ speed = distance * fps # Speed = distance / time (time between frames is 1/fps)
45
+ speed_data.append(speed)
46
+
47
+ last_position = (frame_count, ball_x, ball_y) # Update last position
48
+
49
+ video.release()
50
+
51
+ # Ball trajectory plot
52
+ plot_trajectory(ball_positions)
53
+
54
+ # Return results
55
+ avg_speed = np.mean(speed_data) if speed_data else 0
56
+ return f"Average Ball Speed: {avg_speed:.2f} pixels per second"
57
+
58
+ # Function to plot ball trajectory using matplotlib
59
+ def plot_trajectory(ball_positions):
60
+ x_positions = [pos[1] for pos in ball_positions]
61
+ y_positions = [pos[2] for pos in ball_positions]
62
+
63
+ plt.figure(figsize=(10, 6))
64
+ plt.plot(x_positions, y_positions, label="Ball Trajectory", color='b')
65
+ plt.title("Ball Trajectory on Pitch")
66
+ plt.xlabel("X Position (pitch width)")
67
+ plt.ylabel("Y Position (pitch length)")
68
+ plt.grid(True)
69
+ plt.legend()
70
+ plt.show()
71
+
72
+ # Gradio interface for the app
73
+ iface = gr.Interface(
74
+ fn=process_video, # Function to call when video is uploaded
75
+ inputs=gr.inputs.File(label="Upload a Video File"), # File input (video)
76
+ outputs="text", # Output the result as text
77
+ live=True # Keep the interface live
78
+ )
79
+
80
+ iface.launch(debug=True)