Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from ultralytics import YOLO
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
|
8 |
+
def create_gradient_circle(radius, color=(0, 255, 0), alpha=0.7):
|
9 |
+
"""Create a gradient circle with a glowing effect"""
|
10 |
+
size = radius * 2 + 1
|
11 |
+
center = (radius, radius)
|
12 |
+
circle_img = np.zeros((size, size, 4), dtype=np.uint8)
|
13 |
+
|
14 |
+
for r in range(radius + 1):
|
15 |
+
alpha_r = alpha * (1 - (r/radius)**2)
|
16 |
+
cv2.circle(circle_img, center, r, (*color, int(255 * alpha_r)), -1)
|
17 |
+
|
18 |
+
return circle_img
|
19 |
+
|
20 |
+
def draw_advanced_keypoint(frame, center, keypoint_id, conf, radius=12):
|
21 |
+
"""Draw an advanced technical keypoint with class ID"""
|
22 |
+
x, y = center
|
23 |
+
color = (0, 255, 0)
|
24 |
+
|
25 |
+
gradient = create_gradient_circle(radius + 4, color)
|
26 |
+
|
27 |
+
gx1, gy1 = max(0, x-radius-4), max(0, y-radius-4)
|
28 |
+
gx2, gy2 = min(frame.shape[1], x+radius+5), min(frame.shape[0], y+radius+5)
|
29 |
+
|
30 |
+
if gx1 < gx2 and gy1 < gy2:
|
31 |
+
roi = frame[gy1:gy2, gx1:gx2]
|
32 |
+
gradient_roi = gradient[:gy2-gy1, :gx2-gx1]
|
33 |
+
alpha = gradient_roi[:, :, 3:4] / 255.0
|
34 |
+
roi[:] = roi * (1 - alpha) + gradient_roi[:, :, :3] * alpha
|
35 |
+
|
36 |
+
cv2.circle(frame, center, radius, color, 2)
|
37 |
+
cv2.circle(frame, center, radius-2, color, -1)
|
38 |
+
cv2.circle(frame, center, radius-1, (255, 255, 255), 1)
|
39 |
+
|
40 |
+
label_text = f"{keypoint_id}:{conf:.2f}"
|
41 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
42 |
+
font_scale = 1.0
|
43 |
+
thickness = 2
|
44 |
+
|
45 |
+
(text_w, text_h), baseline = cv2.getTextSize(label_text, font, font_scale, thickness)
|
46 |
+
|
47 |
+
margin = 2
|
48 |
+
bg_pts = np.array([
|
49 |
+
[x - text_w//2 - margin, y - radius - text_h - margin*2],
|
50 |
+
[x + text_w//2 + margin, y - radius - text_h - margin*2],
|
51 |
+
[x + text_w//2 + margin, y - radius - margin],
|
52 |
+
[x + margin, y - radius + margin],
|
53 |
+
[x - margin, y - radius + margin],
|
54 |
+
[x - text_w//2 - margin, y - radius - margin],
|
55 |
+
], np.int32)
|
56 |
+
|
57 |
+
cv2.fillPoly(frame, [bg_pts], (0, 0, 0))
|
58 |
+
cv2.polylines(frame, [bg_pts], True, color, 1)
|
59 |
+
|
60 |
+
cv2.putText(frame, label_text,
|
61 |
+
(x - text_w//2, y - radius - margin*2),
|
62 |
+
font, font_scale, (255, 255, 255), thickness)
|
63 |
+
|
64 |
+
def process_image(input_image, conf_threshold=0.5):
|
65 |
+
"""Process image for pose estimation"""
|
66 |
+
# Load model
|
67 |
+
model_path = "HockeyRink.pt"
|
68 |
+
model = YOLO(model_path)
|
69 |
+
|
70 |
+
# Convert Gradio image to CV2 format if necessary
|
71 |
+
if isinstance(input_image, str):
|
72 |
+
frame = cv2.imread(input_image)
|
73 |
+
else:
|
74 |
+
frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
|
75 |
+
|
76 |
+
# Make prediction
|
77 |
+
results = model.predict(frame, conf=conf_threshold)
|
78 |
+
|
79 |
+
# Create copy for annotation
|
80 |
+
annotated_frame = frame.copy()
|
81 |
+
|
82 |
+
# Process each detection
|
83 |
+
for result in results:
|
84 |
+
if result.keypoints is not None:
|
85 |
+
keypoints = result.keypoints.data[0]
|
86 |
+
|
87 |
+
# Draw class label
|
88 |
+
if hasattr(result, 'boxes') and len(result.boxes.cls) > 0:
|
89 |
+
class_id = int(result.boxes.cls[0])
|
90 |
+
class_conf = float(result.boxes.conf[0])
|
91 |
+
|
92 |
+
if len(keypoints) > 0:
|
93 |
+
text_x = int(min(kp[0] for kp in keypoints))
|
94 |
+
text_y = int(min(kp[1] for kp in keypoints)) - 40
|
95 |
+
|
96 |
+
main_label = f"Class ID:{class_id} ({class_conf:.2f})"
|
97 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
98 |
+
font_scale = 1.2
|
99 |
+
thickness = 2
|
100 |
+
|
101 |
+
(text_w, text_h), baseline = cv2.getTextSize(main_label, font, font_scale, thickness)
|
102 |
+
|
103 |
+
cv2.rectangle(annotated_frame,
|
104 |
+
(text_x - 5, text_y - text_h - 5),
|
105 |
+
(text_x + text_w + 5, text_y + 5),
|
106 |
+
(0, 0, 0), -1)
|
107 |
+
|
108 |
+
cv2.putText(annotated_frame, main_label,
|
109 |
+
(text_x, text_y),
|
110 |
+
font, font_scale,
|
111 |
+
(255, 255, 255), thickness)
|
112 |
+
|
113 |
+
# Draw keypoints
|
114 |
+
for idx, kp in enumerate(keypoints):
|
115 |
+
x, y, conf = int(kp[0]), int(kp[1]), kp[2]
|
116 |
+
if conf > conf_threshold:
|
117 |
+
draw_advanced_keypoint(
|
118 |
+
annotated_frame,
|
119 |
+
(x, y),
|
120 |
+
idx,
|
121 |
+
conf
|
122 |
+
)
|
123 |
+
|
124 |
+
# Convert back to RGB for Gradio
|
125 |
+
return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
|
126 |
+
|
127 |
+
# Create Gradio interface
|
128 |
+
def create_interface():
|
129 |
+
examples = [
|
130 |
+
["exm_1.jpg"],
|
131 |
+
["exm_2.jpg"],
|
132 |
+
["exm_3.jpg"],
|
133 |
+
["exm_4.jpg"],
|
134 |
+
]
|
135 |
+
|
136 |
+
iface = gr.Interface(
|
137 |
+
fn=process_image,
|
138 |
+
inputs=[
|
139 |
+
gr.Image(type="numpy", label="Input Image"),
|
140 |
+
],
|
141 |
+
outputs=gr.Image(type="numpy", label="Detected Poses"),
|
142 |
+
title="Hockey Pose Estimation",
|
143 |
+
description="Upload an image of ice hockey players to detect pose keypoints.",
|
144 |
+
examples=examples,
|
145 |
+
theme=gr.themes.Base()
|
146 |
+
)
|
147 |
+
return iface
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
iface = create_interface()
|
151 |
+
iface.launch()
|