MehdiH7 commited on
Commit
68f0979
·
verified ·
1 Parent(s): 5a9905d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from ultralytics import YOLO
5
+ import torch
6
+ import os
7
+
8
+ def create_gradient_circle(radius, color=(0, 255, 0), alpha=0.7):
9
+ """Create a gradient circle with a glowing effect"""
10
+ size = radius * 2 + 1
11
+ center = (radius, radius)
12
+ circle_img = np.zeros((size, size, 4), dtype=np.uint8)
13
+
14
+ for r in range(radius + 1):
15
+ alpha_r = alpha * (1 - (r/radius)**2)
16
+ cv2.circle(circle_img, center, r, (*color, int(255 * alpha_r)), -1)
17
+
18
+ return circle_img
19
+
20
+ def draw_advanced_keypoint(frame, center, keypoint_id, conf, radius=12):
21
+ """Draw an advanced technical keypoint with class ID"""
22
+ x, y = center
23
+ color = (0, 255, 0)
24
+
25
+ gradient = create_gradient_circle(radius + 4, color)
26
+
27
+ gx1, gy1 = max(0, x-radius-4), max(0, y-radius-4)
28
+ gx2, gy2 = min(frame.shape[1], x+radius+5), min(frame.shape[0], y+radius+5)
29
+
30
+ if gx1 < gx2 and gy1 < gy2:
31
+ roi = frame[gy1:gy2, gx1:gx2]
32
+ gradient_roi = gradient[:gy2-gy1, :gx2-gx1]
33
+ alpha = gradient_roi[:, :, 3:4] / 255.0
34
+ roi[:] = roi * (1 - alpha) + gradient_roi[:, :, :3] * alpha
35
+
36
+ cv2.circle(frame, center, radius, color, 2)
37
+ cv2.circle(frame, center, radius-2, color, -1)
38
+ cv2.circle(frame, center, radius-1, (255, 255, 255), 1)
39
+
40
+ label_text = f"{keypoint_id}:{conf:.2f}"
41
+ font = cv2.FONT_HERSHEY_SIMPLEX
42
+ font_scale = 1.0
43
+ thickness = 2
44
+
45
+ (text_w, text_h), baseline = cv2.getTextSize(label_text, font, font_scale, thickness)
46
+
47
+ margin = 2
48
+ bg_pts = np.array([
49
+ [x - text_w//2 - margin, y - radius - text_h - margin*2],
50
+ [x + text_w//2 + margin, y - radius - text_h - margin*2],
51
+ [x + text_w//2 + margin, y - radius - margin],
52
+ [x + margin, y - radius + margin],
53
+ [x - margin, y - radius + margin],
54
+ [x - text_w//2 - margin, y - radius - margin],
55
+ ], np.int32)
56
+
57
+ cv2.fillPoly(frame, [bg_pts], (0, 0, 0))
58
+ cv2.polylines(frame, [bg_pts], True, color, 1)
59
+
60
+ cv2.putText(frame, label_text,
61
+ (x - text_w//2, y - radius - margin*2),
62
+ font, font_scale, (255, 255, 255), thickness)
63
+
64
+ def process_image(input_image, conf_threshold=0.5):
65
+ """Process image for pose estimation"""
66
+ # Load model
67
+ model_path = "HockeyRink.pt"
68
+ model = YOLO(model_path)
69
+
70
+ # Convert Gradio image to CV2 format if necessary
71
+ if isinstance(input_image, str):
72
+ frame = cv2.imread(input_image)
73
+ else:
74
+ frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
75
+
76
+ # Make prediction
77
+ results = model.predict(frame, conf=conf_threshold)
78
+
79
+ # Create copy for annotation
80
+ annotated_frame = frame.copy()
81
+
82
+ # Process each detection
83
+ for result in results:
84
+ if result.keypoints is not None:
85
+ keypoints = result.keypoints.data[0]
86
+
87
+ # Draw class label
88
+ if hasattr(result, 'boxes') and len(result.boxes.cls) > 0:
89
+ class_id = int(result.boxes.cls[0])
90
+ class_conf = float(result.boxes.conf[0])
91
+
92
+ if len(keypoints) > 0:
93
+ text_x = int(min(kp[0] for kp in keypoints))
94
+ text_y = int(min(kp[1] for kp in keypoints)) - 40
95
+
96
+ main_label = f"Class ID:{class_id} ({class_conf:.2f})"
97
+ font = cv2.FONT_HERSHEY_SIMPLEX
98
+ font_scale = 1.2
99
+ thickness = 2
100
+
101
+ (text_w, text_h), baseline = cv2.getTextSize(main_label, font, font_scale, thickness)
102
+
103
+ cv2.rectangle(annotated_frame,
104
+ (text_x - 5, text_y - text_h - 5),
105
+ (text_x + text_w + 5, text_y + 5),
106
+ (0, 0, 0), -1)
107
+
108
+ cv2.putText(annotated_frame, main_label,
109
+ (text_x, text_y),
110
+ font, font_scale,
111
+ (255, 255, 255), thickness)
112
+
113
+ # Draw keypoints
114
+ for idx, kp in enumerate(keypoints):
115
+ x, y, conf = int(kp[0]), int(kp[1]), kp[2]
116
+ if conf > conf_threshold:
117
+ draw_advanced_keypoint(
118
+ annotated_frame,
119
+ (x, y),
120
+ idx,
121
+ conf
122
+ )
123
+
124
+ # Convert back to RGB for Gradio
125
+ return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
126
+
127
+ # Create Gradio interface
128
+ def create_interface():
129
+ examples = [
130
+ ["exm_1.jpg"],
131
+ ["exm_2.jpg"],
132
+ ["exm_3.jpg"],
133
+ ["exm_4.jpg"],
134
+ ]
135
+
136
+ iface = gr.Interface(
137
+ fn=process_image,
138
+ inputs=[
139
+ gr.Image(type="numpy", label="Input Image"),
140
+ ],
141
+ outputs=gr.Image(type="numpy", label="Detected Poses"),
142
+ title="Hockey Pose Estimation",
143
+ description="Upload an image of ice hockey players to detect pose keypoints.",
144
+ examples=examples,
145
+ theme=gr.themes.Base()
146
+ )
147
+ return iface
148
+
149
+ if __name__ == "__main__":
150
+ iface = create_interface()
151
+ iface.launch()