HockeyRink / app.py
MehdiH7's picture
Update app.py
ce3cabf verified
raw
history blame
5.32 kB
import os
import threading

import cv2
import gradio as gr
import numpy as np
import torch
from ultralytics import YOLO
def create_gradient_circle(radius, color=(0, 255, 0), alpha=0.7):
    """Create a 4-channel disk image whose opacity fades towards the rim.

    The returned array has shape ``(2 * radius + 1, 2 * radius + 1, 4)`` and
    dtype ``uint8``. Inside the disk the first three channels hold ``color``
    and the fourth channel holds the opacity, which falls off quadratically
    from ``255 * alpha`` at the centre pixel to 0 at distance ``radius``.
    Pixels outside the disk are all zero.

    Args:
        radius: Non-negative integer radius of the disk in pixels.
        color: Three channel values written unchanged into channels 0-2.
        alpha: Peak opacity at the centre, as a fraction of 255.

    Returns:
        The ``uint8`` image described above.

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")

    size = radius * 2 + 1
    circle_img = np.zeros((size, size, 4), dtype=np.uint8)

    # Squared distance of every pixel from the centre pixel (radius, radius).
    offsets = np.arange(size) - radius
    dist_sq = offsets[:, None] ** 2 + offsets[None, :] ** 2
    inside = dist_sq <= radius * radius

    # Computing the falloff per pixel (instead of stacking filled circles)
    # guarantees the opaque centre is never overwritten by the transparent
    # outer ring. A zero radius degenerates to one pixel at peak opacity.
    if radius > 0:
        falloff = dist_sq / float(radius * radius)
    else:
        falloff = np.zeros(dist_sq.shape, dtype=float)
    alpha_map = np.clip(255 * alpha * (1.0 - falloff), 0, 255)

    circle_img[inside, :3] = color
    circle_img[inside, 3] = alpha_map[inside].astype(np.uint8)
    return circle_img
def draw_advanced_keypoint(frame, center, keypoint_id, conf, radius=12):
    """Draw a glowing keypoint marker with an "id:confidence" label on `frame`.

    The frame is modified in place; nothing is returned.

    Args:
        frame: Image array indexed as ``frame[y, x]`` whose pixels are
            blended with a 3-channel glow patch and drawn on with OpenCV.
        center: ``(x, y)`` integer pixel coordinates of the keypoint.
        keypoint_id: Value shown before the colon in the label.
        conf: Value shown after the colon, formatted with two decimals.
        radius: Radius in pixels of the solid marker; the glow patch
            extends 4 pixels beyond it.
    """
    x, y = center
    color = (0, 255, 0)
    glow_radius = radius + 4

    gradient = create_gradient_circle(glow_radius, color)

    # Top-left corner of the full (unclipped) glow patch in frame coordinates.
    left, top = x - glow_radius, y - glow_radius
    gx1, gy1 = max(0, left), max(0, top)
    gx2 = min(frame.shape[1], x + glow_radius + 1)
    gy2 = min(frame.shape[0], y + glow_radius + 1)

    if gx1 < gx2 and gy1 < gy2:
        roi = frame[gy1:gy2, gx1:gx2]
        # When the patch is clipped at the left/top edge, skip the part of
        # the gradient that falls outside the frame so the glow stays
        # centred on the keypoint.
        off_x, off_y = gx1 - left, gy1 - top
        gradient_roi = gradient[off_y:off_y + (gy2 - gy1),
                                off_x:off_x + (gx2 - gx1)]
        alpha = gradient_roi[:, :, 3:4] / 255.0
        roi[:] = roi * (1 - alpha) + gradient_roi[:, :, :3] * alpha

    # Solid marker: outer ring, filled core, thin white highlight.
    cv2.circle(frame, center, radius, color, 2)
    cv2.circle(frame, center, radius - 2, color, -1)
    cv2.circle(frame, center, radius - 1, (255, 255, 255), 1)

    label_text = f"{keypoint_id}:{conf:.2f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2
    (text_w, text_h), baseline = cv2.getTextSize(
        label_text, font, font_scale, thickness
    )

    # Label background: a box above the marker with a small notch pointing
    # down towards the keypoint.
    margin = 2
    bg_pts = np.array([
        [x - text_w // 2 - margin, y - radius - text_h - margin * 2],
        [x + text_w // 2 + margin, y - radius - text_h - margin * 2],
        [x + text_w // 2 + margin, y - radius - margin],
        [x + margin, y - radius + margin],
        [x - margin, y - radius + margin],
        [x - text_w // 2 - margin, y - radius - margin],
    ], np.int32)

    cv2.fillPoly(frame, [bg_pts], (0, 0, 0))
    cv2.polylines(frame, [bg_pts], True, color, 1)
    cv2.putText(frame, label_text,
                (x - text_w // 2, y - radius - margin * 2),
                font, font_scale, (255, 255, 255), thickness)
# Loaded models keyed by weights path, so the weights are read from disk
# only once per process instead of on every request.
_MODEL_CACHE = {}
# Serialises model loading and inference across worker threads.
# NOTE(review): Ultralytics documents that sharing one model instance between
# concurrently predicting threads is unsafe; the lock keeps the cached model
# to one caller at a time.
_MODEL_LOCK = threading.Lock()


def _predict(model_path, frame, conf_threshold):
    """Run the (cached) YOLO model at `model_path` on `frame`."""
    with _MODEL_LOCK:
        model = _MODEL_CACHE.get(model_path)
        if model is None:
            model = YOLO(model_path)
            _MODEL_CACHE[model_path] = model
        return model.predict(frame, conf=conf_threshold)


def _draw_class_label(frame, label, text_x, text_y):
    """Draw white `label` text on a black box, kept inside the frame's top-left bounds."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
    # Keypoints near the top/left edge would otherwise push the label
    # (anchored 40 px above them) partly or fully out of view.
    text_x = max(text_x, 5)
    text_y = max(text_y, text_h + 5)
    cv2.rectangle(frame,
                  (text_x - 5, text_y - text_h - 5),
                  (text_x + text_w + 5, text_y + 5),
                  (0, 0, 0), -1)
    cv2.putText(frame, label, (text_x, text_y),
                font, font_scale, (255, 255, 255), thickness)


def process_image(input_image, conf_threshold=0.5):
    """Detect keypoints in an image and return an annotated copy.

    Args:
        input_image: Either a path to an image file or an RGB image array.
        conf_threshold: Confidence passed to the model's ``predict`` call and
            also used as the minimum per-keypoint confidence for drawing.

    Returns:
        An RGB image array: a copy of the input with a class label and the
        keypoints of the first detection of each result drawn on it.

    Raises:
        gr.Error: If no image was provided or the image file cannot be read.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # OpenCV works in BGR; Gradio hands over RGB arrays.
    if isinstance(input_image, str):
        frame = cv2.imread(input_image)
        if frame is None:
            raise gr.Error(f"Could not read image file: {input_image}")
    else:
        frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    results = _predict("HockeyRink.pt", frame, conf_threshold)

    annotated_frame = frame.copy()

    for result in results:
        # A result can carry a keypoints object with zero detections;
        # indexing [0] on it would raise IndexError.
        if result.keypoints is None or len(result.keypoints.data) == 0:
            continue

        # Convert once to plain Python numbers instead of touching the
        # tensor element by element in the loops below.
        keypoints = result.keypoints.data[0].tolist()

        boxes = getattr(result, "boxes", None)
        if boxes is not None and len(boxes.cls) > 0 and keypoints:
            class_id = int(boxes.cls[0])
            class_conf = float(boxes.conf[0])
            # Anchor the label above the top-left-most keypoint coordinates.
            text_x = int(min(kp[0] for kp in keypoints))
            text_y = int(min(kp[1] for kp in keypoints)) - 40
            main_label = f"Class ID:{class_id} ({class_conf:.2f})"
            _draw_class_label(annotated_frame, main_label, text_x, text_y)

        for idx, kp in enumerate(keypoints):
            x, y, conf = int(kp[0]), int(kp[1]), kp[2]
            if conf > conf_threshold:
                draw_advanced_keypoint(annotated_frame, (x, y), idx, conf)

    # Convert back to RGB for Gradio.
    return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
# Create Gradio interface
def create_interface():
    """Build and return the Gradio interface wrapping `process_image`.

    The interface takes one image input, shows one image output, and lists
    four example image files.
    """
    example_files = ("exm_1.jpg", "exm_2.jpg", "exm_3.jpg", "exm_4.jpg")

    image_input = gr.Image(type="numpy", label="Input Image")
    image_output = gr.Image(type="numpy", label="Detected Poses")

    return gr.Interface(
        fn=process_image,
        inputs=[image_input],
        outputs=image_output,
        title="HockeyRink: A Model for Precise Ice Hockey Rink Keypoint Mapping and Analytics",
        description="Upload an image of ice hockey to detect keypoints on the rink.",
        # Each example is a one-element row because there is one input.
        examples=[[file_name] for file_name in example_files],
        theme=gr.themes.Base(),
    )
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    create_interface().launch()