# Hugging Face Space app (Space status: Running)
import os
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
import torch
from ultralytics import YOLO
def create_gradient_circle(radius, color=(0, 255, 0), alpha=0.7):
    """Create a square 4-channel sprite holding a radially fading disc.

    The disc is centred in a ``(2*radius + 1, 2*radius + 1, 4)`` uint8 array.
    Channels 0-2 hold ``color`` inside the disc and 0 outside. Channel 3 holds
    an opacity of ``255 * alpha`` at the centre pixel, which falls off as
    ``1 - (d / radius) ** 2`` to 0 at distance ``d == radius``. This produces
    a glow when the sprite is alpha-blended onto a frame.

    Args:
        radius: Disc radius in pixels (non-negative; converted with ``int``).
        color: 3-tuple written to channels 0-2 (this app passes BGR).
        alpha: Peak opacity, in [0, 1], at the centre pixel.

    Returns:
        np.ndarray of shape ``(2*radius + 1, 2*radius + 1, 4)``, dtype uint8.

    Raises:
        ValueError: If ``radius`` is negative.
    """
    radius = int(radius)
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")
    size = radius * 2 + 1

    # Euclidean distance of every pixel from the centre pixel (radius, radius).
    offsets = np.arange(size) - radius
    dist = np.hypot(offsets[:, None], offsets[None, :])
    inside = dist <= radius

    # Quadratic falloff. A zero radius is a single fully-weighted pixel
    # (the previous formula divided by zero in that case).
    if radius > 0:
        falloff = 1.0 - (dist / radius) ** 2
    else:
        falloff = np.ones_like(dist)

    circle_img = np.zeros((size, size, 4), dtype=np.uint8)
    circle_img[inside, :3] = color
    # Computing the alpha per pixel, instead of stacking filled cv2 circles from
    # small to large, keeps the gradient. Stacking let the outermost
    # (alpha 0) circle overwrite the whole disc.
    opacity = np.clip(255.0 * alpha * falloff, 0, 255)
    circle_img[..., 3] = np.where(inside, opacity, 0).astype(np.uint8)
    return circle_img
def draw_advanced_keypoint(frame, center, keypoint_id, conf, radius=12):
    """Draw one keypoint marker, with a glow and an "id:confidence" callout, in place.

    Args:
        frame: 3-channel uint8 image (BGR in this app); modified in place.
        center: ``(x, y)`` integer pixel coordinates of the keypoint.
        keypoint_id: Keypoint index printed in the callout label.
        conf: Keypoint confidence, printed with two decimals. It may be a
            Python float or a 0-d tensor/array; it is converted with ``float``.
        radius: Radius of the solid marker. The glow extends 4 px beyond it.
    """
    x, y = center
    color = (0, 255, 0)

    # --- Glow: alpha-blend a gradient sprite centred on (x, y). ---
    glow_r = radius + 4
    gradient = create_gradient_circle(glow_r, color)
    frame_h, frame_w = frame.shape[:2]
    gx1, gy1 = max(0, x - glow_r), max(0, y - glow_r)
    gx2, gy2 = min(frame_w, x + glow_r + 1), min(frame_h, y + glow_r + 1)
    if gx1 < gx2 and gy1 < gy2:
        # When the sprite is clipped at the left/top edge, skip the clipped-off
        # part of the sprite too. Otherwise the glow is drawn off-centre.
        ox, oy = gx1 - (x - glow_r), gy1 - (y - glow_r)
        roi = frame[gy1:gy2, gx1:gx2]
        sprite = gradient[oy:oy + (gy2 - gy1), ox:ox + (gx2 - gx1)]
        alpha = sprite[:, :, 3:4] / 255.0
        roi[:] = (roi * (1.0 - alpha) + sprite[:, :, :3] * alpha).astype(roi.dtype)

    # --- Marker: outer ring, filled core, thin white highlight ring. ---
    # cv2.circle rejects negative radii, so clamp for very small `radius`.
    cv2.circle(frame, center, radius, color, 2)
    cv2.circle(frame, center, max(radius - 2, 0), color, -1)
    cv2.circle(frame, center, max(radius - 1, 0), (255, 255, 255), 1)

    # --- Callout: black box with a small notch pointing down at the marker. ---
    label_text = f"{keypoint_id}:{float(conf):.2f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2
    (text_w, text_h), _baseline = cv2.getTextSize(label_text, font, font_scale, thickness)
    margin = 2
    half_w = text_w // 2
    box_top = y - radius - text_h - margin * 2
    box_bottom = y - radius - margin
    bg_pts = np.array([
        [x - half_w - margin, box_top],
        [x + half_w + margin, box_top],
        [x + half_w + margin, box_bottom],
        [x + margin, y - radius + margin],
        [x - margin, y - radius + margin],
        [x - half_w - margin, box_bottom],
    ], np.int32)
    cv2.fillPoly(frame, [bg_pts], (0, 0, 0))
    cv2.polylines(frame, [bg_pts], True, color, 1)
    cv2.putText(frame, label_text,
                (x - half_w, y - radius - margin * 2),
                font, font_scale, (255, 255, 255), thickness)
@lru_cache(maxsize=None)
def _load_model(model_path="HockeyRink.pt"):
    """Load the YOLO weights at *model_path* once and reuse them across requests."""
    return YOLO(model_path)


def _draw_detection_label(frame, label, anchor_x, anchor_y):
    """Draw *label* in white on a filled black box whose text baseline starts at the anchor."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
    # The anchor is placed above the topmost keypoint and can land outside the
    # frame when keypoints are near the top/left edge, so keep the box on-screen.
    anchor_x = max(anchor_x, 5)
    anchor_y = max(anchor_y, text_h + 5)
    cv2.rectangle(frame,
                  (anchor_x - 5, anchor_y - text_h - 5),
                  (anchor_x + text_w + 5, anchor_y + 5),
                  (0, 0, 0), -1)
    cv2.putText(frame, label, (anchor_x, anchor_y),
                font, font_scale, (255, 255, 255), thickness)


def process_image(input_image, conf_threshold=0.5):
    """Run HockeyRink keypoint estimation on one image and return an annotated copy.

    Only the first detection of each result is annotated (as before). It gets
    a class-id/confidence label plus one marker per keypoint whose confidence
    exceeds ``conf_threshold``.

    Args:
        input_image: RGB ``np.ndarray`` as delivered by
            ``gr.Image(type="numpy")``, or a path readable by ``cv2.imread``.
        conf_threshold: Used both as the detector's ``conf`` filter and as the
            per-keypoint visibility cutoff.

    Returns:
        The annotated image as an RGB ``np.ndarray``.

    Raises:
        gr.Error: If no image was supplied or the path could not be read.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if isinstance(input_image, str):
        frame = cv2.imread(input_image)  # cv2.imread already returns BGR
        if frame is None:
            raise gr.Error(f"Could not read image: {input_image}")
    else:
        frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    # Loading weights is expensive; _load_model caches the instance.
    results = _load_model().predict(frame, conf=conf_threshold)
    annotated_frame = frame.copy()

    for result in results:
        # With no detections the keypoint tensor is empty, and indexing [0]
        # would raise IndexError, so skip such results.
        if result.keypoints is None or len(result.keypoints.data) == 0:
            continue
        keypoints = result.keypoints.data[0].cpu().numpy()  # rows of (x, y, conf)

        if result.boxes is not None and len(result.boxes.cls) > 0 and len(keypoints) > 0:
            class_id = int(result.boxes.cls[0])
            class_conf = float(result.boxes.conf[0])
            # Anchor the label on visible keypoints. Low-confidence points can be
            # predicted far outside the rink and would push the label off-frame.
            visible = keypoints[keypoints[:, 2] > conf_threshold]
            anchor_pts = visible if len(visible) > 0 else keypoints
            text_x = int(anchor_pts[:, 0].min())
            text_y = int(anchor_pts[:, 1].min()) - 40
            _draw_detection_label(annotated_frame,
                                  f"Class ID:{class_id} ({class_conf:.2f})",
                                  text_x, text_y)

        for idx, (kp_x, kp_y, kp_conf) in enumerate(keypoints):
            if kp_conf > conf_threshold:
                draw_advanced_keypoint(annotated_frame, (int(kp_x), int(kp_y)),
                                       idx, float(kp_conf))

    # Gradio expects RGB.
    return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
# Create Gradio interface
def create_interface():
    """Build the Gradio UI that wraps ``process_image``.

    The UI exposes both ``process_image`` inputs: the image and the confidence
    threshold. Previously the threshold was fixed at its 0.5 default.

    Returns:
        gr.Interface: The configured interface. Call ``.launch()`` to serve it.
    """
    default_conf = 0.5
    example_files = ["exm_1.jpg", "exm_2.jpg", "exm_3.jpg", "exm_4.jpg"]
    # Each example row must supply one value per input component.
    examples = [[path, default_conf] for path in example_files]

    iface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
            gr.Slider(minimum=0.05, maximum=0.95, value=default_conf, step=0.05,
                      label="Confidence Threshold"),
        ],
        outputs=gr.Image(type="numpy", label="Detected Poses"),
        title="HockeyRink: A Model for Precise Ice Hockey Rink Keypoint Mapping and Analytics",
        description="Upload an image of ice hockey to detect keypoints on the rink.",
        examples=examples,
        theme=gr.themes.Base(),
    )
    return iface
if __name__ == "__main__":
    # Script entry point (e.g. `python app.py` on a Space): build the UI,
    # then hand control to Gradio's server loop.
    iface = create_interface()
    iface.launch()