MehdiH7 committed
Commit 1d7d324 · verified · 1 Parent(s): f2b30de

Update app.py

Files changed (1):
    app.py  +73 -92
app.py CHANGED
@@ -3,12 +3,9 @@ import gradio as gr
import torch
from torchvision import transforms, models
import cv2
-from PIL import Image
import numpy as np
+from PIL import Image
from ultralytics import YOLO
-import os
-
-DEMO_VIDEO = "hockey_sample_5s.mp4"

def load_models():
    # Initialize YOLO
@@ -23,14 +20,21 @@ def load_models():

    return yolo_model, squeezenet_model

-def process_video(video_path):
-    if video_path is None:
-        video_path = DEMO_VIDEO
-
+def process_image(input_image):
+    if input_image is None:
+        return None
+
+    # Convert to numpy array if needed
+    if isinstance(input_image, str):
+        image = cv2.imread(input_image)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    else:
+        image = input_image.copy()
+
    # Initialize models
    yolo_model, squeezenet_model = load_models()

-    # Class labels
+    # Class labels for direction
    class_labels = [
        "Bottom", "Bottom_Left", "Bottom_Right", "Left",
        "Right", "Top", "Top_Left", "Top_Right"
@@ -43,106 +47,83 @@ def process_video(video_path):
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

-    # Read video
-    cap = cv2.VideoCapture(video_path)
-    fps = int(cap.get(cv2.CAP_PROP_FPS))
-    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    # Run YOLO detection
+    results = yolo_model(image)

-    # Prepare video writer
-    output_path = "output_video.mp4"
-    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
-
-    while cap.isOpened():
-        ret, frame = cap.read()
-        if not ret:
-            break
-
-        # Run YOLO detection
-        results = yolo_model(frame)
+    # Process each detection
+    for box in results[0].boxes:
+        xyxy = box.xyxy[0].cpu().numpy()
+        conf = float(box.conf[0].cpu().numpy())
+        cls = int(box.cls[0].cpu().numpy())

-        # Process each detection
-        for box in results[0].boxes:
-            xyxy = box.xyxy[0].cpu().numpy()
-            conf = float(box.conf[0].cpu().numpy())
-            cls = int(box.cls[0].cpu().numpy())
+        # Process only if it's a player (class 4) and confidence is above threshold
+        if cls == 4 and conf > 0.5:
+            x1, y1, x2, y2 = map(int, xyxy)

-            # Process only if it's a player (class 4) and confidence is above threshold
-            if cls == 4 and conf > 0.5:
-                x1, y1, x2, y2 = map(int, xyxy)
-
-                # Crop and process for direction classification
-                if x2 > x1 and y2 > y1:
-                    cropped_array = frame[y1:y2, x1:x2]
-                    if cropped_array.size > 0:
-                        cropped_image = Image.fromarray(cv2.cvtColor(cropped_array, cv2.COLOR_BGR2RGB))
+            # Crop and process for direction classification
+            if x2 > x1 and y2 > y1:
+                cropped_array = image[y1:y2, x1:x2]
+                if cropped_array.size > 0:
+                    cropped_image = Image.fromarray(cropped_array)
+
+                    # Predict direction
+                    image_tensor = transform(cropped_image).unsqueeze(0)
+                    with torch.no_grad():
+                        output = squeezenet_model(image_tensor)
+                        direction_class = torch.argmax(output, dim=1).item()
+
+                    # Draw annotations
+                    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+                    cv2.putText(image, f"{conf:.2f}", (x1, y1-10),
+                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+
+                    # Draw direction arrow
+                    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
+                    arrow_length = 50
+                    direction = class_labels[direction_class]
+
+                    # Calculate arrow endpoint
+                    end_x, end_y = center_x, center_y
+                    if "Top" in direction:
+                        end_y = center_y - arrow_length
+                    elif "Bottom" in direction:
+                        end_y = center_y + arrow_length
+                    if "Left" in direction:
+                        end_x = center_x - arrow_length
+                    elif "Right" in direction:
+                        end_x = center_x + arrow_length

-                        # Predict direction
-                        image_tensor = transform(cropped_image).unsqueeze(0)
-                        with torch.no_grad():
-                            output = squeezenet_model(image_tensor)
-                            direction_class = torch.argmax(output, dim=1).item()
-
-                        # Draw annotations
-                        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
-                        cv2.putText(frame, f"{conf:.2f}", (x1, y1-10),
-                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
-
-                        # Draw direction arrow
-                        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
-                        arrow_length = 50
-                        direction = class_labels[direction_class]
-
-                        # Calculate arrow endpoint
-                        end_x, end_y = center_x, center_y
-                        if "Top" in direction:
-                            end_y = center_y - arrow_length
-                        elif "Bottom" in direction:
-                            end_y = center_y + arrow_length
-                        if "Left" in direction:
-                            end_x = center_x - arrow_length
-                        elif "Right" in direction:
-                            end_x = center_x + arrow_length
-
-                        cv2.arrowedLine(frame, (center_x, center_y), (end_x, end_y),
-                                        (0, 0, 255), 2, tipLength=0.3)
-
-        out.write(frame)
+                    cv2.arrowedLine(image, (center_x, center_y), (end_x, end_y),
+                                    (0, 0, 255), 2, tipLength=0.3)

-    cap.release()
-    out.release()
-
-    return output_path
-
-def example_video():
-    return DEMO_VIDEO
+    return image

# Create Gradio interface
def gradio_interface():
    with gr.Blocks() as iface:
        gr.Markdown("# Player Direction Detection")
-        gr.Markdown("Upload a video or use the demo video to detect players and their movement directions")
+        gr.Markdown("Upload an image to detect players and their movement directions")

        with gr.Row():
            with gr.Column():
-                input_video = gr.Video(label="Input Video")
-                demo_button = gr.Button("Use Demo Video")
-
+                input_image = gr.Image(label="Input Image", type="numpy")
            with gr.Column():
-                output_video = gr.Video(label="Output Video")
+                output_image = gr.Image(label="Output Image")

-        # Handle demo button click
-        demo_button.click(
-            fn=example_video,
-            outputs=input_video
+        # Handle image processing
+        input_image.change(
+            fn=process_image,
+            inputs=[input_image],
+            outputs=[output_image]
        )

-        # Handle video processing
-        input_video.change(
-            fn=process_video,
-            inputs=[input_video],
-            outputs=[output_video]
+        # Add example images if you have them
+        gr.Examples(
+            examples=["example-1.jpg", "example-2.jpg"],
+            inputs=input_image,
+            outputs=output_image,
+            fn=process_image,
+            cache_examples=True
        )

    return iface
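
Note: the hunks above end at "return iface"; the call that actually launches the app sits outside the changed region of app.py. For a Gradio Space like this one, the entry point typically looks like the sketch below (an assumption added for context, not part of this commit):

    # Hypothetical entry point, not shown in this diff
    if __name__ == "__main__":
        demo = gradio_interface()  # build the Blocks app defined above
        demo.launch()              # start the Gradio server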