import functools
import threading

import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
import torch
import os

# Weights file loaded (once) by _get_model().
MODEL_PATH = "HockeyRink.pt"

# The cached model is shared between requests; serialize load + inference so
# concurrent Gradio workers never use the same model object at the same time.
_MODEL_LOCK = threading.Lock()


def create_gradient_circle(radius, color=(0, 255, 0), alpha=0.7):
    """Create a 4-channel image of a filled circle whose opacity fades outward.

    Args:
        radius: Outer radius in pixels (must be >= 0). The image is
            ``2 * radius + 1`` pixels square, centred on ``(radius, radius)``.
        color: 3-channel colour written into the first three channels.
        alpha: Peak opacity (0-1) at the centre; it falls off quadratically
            to 0 at ``radius``.

    Returns:
        ``uint8`` array of shape ``(2 * radius + 1, 2 * radius + 1, 4)`` whose
        channel 3 holds the per-pixel opacity (0-255).

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError(f"radius must be >= 0, got {radius}")
    size = radius * 2 + 1
    center = (radius, radius)
    circle_img = np.zeros((size, size, 4), dtype=np.uint8)
    # Guard the division for radius == 0 (a single fully-opaque-at-alpha pixel).
    denom = max(radius, 1)
    # Draw from the outermost ring inward: each filled circle overwrites the
    # previous one, so the smallest (most opaque) circle must be drawn last.
    # Drawing inward-out would leave the alpha == 0 outer disc covering all.
    for r in range(radius, -1, -1):
        alpha_r = alpha * (1 - (r / denom) ** 2)
        cv2.circle(circle_img, center, r, (*color, int(255 * alpha_r)), -1)
    return circle_img


def draw_advanced_keypoint(frame, center, keypoint_id, conf, radius=12):
    """Draw a glowing keypoint marker with an ``"<id>:<conf>"`` label, in place.

    Args:
        frame: 3-channel ``uint8`` image (BGR as produced by ``process_image``);
            modified in place.
        center: ``(x, y)`` integer pixel coordinates of the keypoint.
        keypoint_id: Index shown in the label.
        conf: Confidence shown in the label with two decimals.
        radius: Radius in pixels of the solid marker; the glow extends 4 px
            beyond it.
    """
    x, y = center
    color = (0, 255, 0)
    glow_radius = radius + 4
    gradient = create_gradient_circle(glow_radius, color)

    # Frame coordinates of the gradient image's top-left corner. This may be
    # outside the frame when the keypoint is close to an edge.
    ox, oy = x - glow_radius, y - glow_radius
    gx1, gy1 = max(0, ox), max(0, oy)
    gx2 = min(frame.shape[1], ox + gradient.shape[1])
    gy2 = min(frame.shape[0], oy + gradient.shape[0])
    if gx1 < gx2 and gy1 < gy2:
        roi = frame[gy1:gy2, gx1:gx2]
        # Slice the gradient with the same offset used to clip the frame so
        # the glow stays centred on the keypoint when clipped at top/left.
        gradient_roi = gradient[gy1 - oy:gy2 - oy, gx1 - ox:gx2 - ox]
        alpha = gradient_roi[:, :, 3:4] / 255.0
        blended = roi * (1 - alpha) + gradient_roi[:, :, :3] * alpha
        roi[:] = blended.astype(np.uint8)

    cv2.circle(frame, center, radius, color, 2)
    cv2.circle(frame, center, radius - 2, color, -1)
    cv2.circle(frame, center, radius - 1, (255, 255, 255), 1)

    label_text = f"{keypoint_id}:{float(conf):.2f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2
    (text_w, text_h), _baseline = cv2.getTextSize(label_text, font, font_scale, thickness)
    margin = 2
    # Label box above the marker with a small downward "tab" pointing at it.
    bg_pts = np.array([
        [x - text_w // 2 - margin, y - radius - text_h - margin * 2],
        [x + text_w // 2 + margin, y - radius - text_h - margin * 2],
        [x + text_w // 2 + margin, y - radius - margin],
        [x + margin, y - radius + margin],
        [x - margin, y - radius + margin],
        [x - text_w // 2 - margin, y - radius - margin],
    ], np.int32)
    cv2.fillPoly(frame, [bg_pts], (0, 0, 0))
    cv2.polylines(frame, [bg_pts], True, color, 1)
    cv2.putText(frame, label_text, (x - text_w // 2, y - radius - margin * 2),
                font, font_scale, (255, 255, 255), thickness)


@functools.lru_cache(maxsize=1)
def _get_model(model_path=MODEL_PATH):
    """Load the YOLO model once and reuse it for every request."""
    return YOLO(model_path)


def _draw_class_label(frame, keypoints, class_id, class_conf):
    """Draw ``"Class ID:<id> (<conf>)"`` on a black box above the keypoints."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    main_label = f"Class ID:{class_id} ({class_conf:.2f})"
    (text_w, text_h), _baseline = cv2.getTextSize(main_label, font, font_scale, thickness)

    # Anchor at the keypoints' minimum x and 40 px above their minimum y,
    # clamped so the label (and its 5 px padding) stays inside the frame.
    text_x = max(int(keypoints[:, 0].min()), 5)
    text_y = max(int(keypoints[:, 1].min()) - 40, text_h + 5)

    cv2.rectangle(frame,
                  (text_x - 5, text_y - text_h - 5),
                  (text_x + text_w + 5, text_y + 5),
                  (0, 0, 0), -1)
    cv2.putText(frame, main_label, (text_x, text_y), font, font_scale,
                (255, 255, 255), thickness)


def _annotate_result(frame, result, conf_threshold):
    """Draw the class label and keypoints of the first detection in *result*.

    Only the first detection is drawn, matching the original behaviour;
    presumably one rink per image is expected — confirm if multi-detection
    output matters.
    """
    if result.keypoints is None or len(result.keypoints.data) == 0:
        # No detections in this image: nothing to draw.
        return
    keypoints = result.keypoints.data[0]
    if len(keypoints) == 0:
        return

    boxes = getattr(result, "boxes", None)
    if boxes is not None and len(boxes.cls) > 0:
        _draw_class_label(frame, keypoints, int(boxes.cls[0]), float(boxes.conf[0]))

    for idx, kp in enumerate(keypoints):
        x, y = int(kp[0]), int(kp[1])
        # Keypoints may carry only (x, y) when the model has no visibility
        # channel; treat those as fully confident.
        conf = float(kp[2]) if len(kp) > 2 else 1.0
        if conf > conf_threshold:
            draw_advanced_keypoint(frame, (x, y), idx, conf)


def process_image(input_image, conf_threshold=0.5):
    """Run rink keypoint estimation on an image and return an annotated copy.

    Args:
        input_image: RGB ``numpy`` array (as supplied by
            ``gr.Image(type="numpy")``) or a file path, which is read with
            ``cv2.imread`` (BGR).
        conf_threshold: Passed to ``model.predict`` as ``conf`` and also used
            to hide keypoints whose confidence is not above it.

    Returns:
        The annotated image as an RGB ``numpy`` array.

    Raises:
        gr.Error: If no image was provided or the path could not be read.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # Bring the input to the BGR layout OpenCV drawing expects.
    if isinstance(input_image, str):
        frame = cv2.imread(input_image)
        if frame is None:
            raise gr.Error(f"Could not read image: {input_image}")
    else:
        frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    with _MODEL_LOCK:
        results = _get_model().predict(frame, conf=conf_threshold)

    annotated_frame = frame.copy()
    for result in results:
        _annotate_result(annotated_frame, result, conf_threshold)

    # Convert back to RGB for Gradio.
    return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)


def create_interface():
    """Build the Gradio interface that wraps ``process_image``."""
    examples = [
        ["exm_1.jpg"],
        ["exm_2.jpg"],
        ["exm_3.jpg"],
        ["exm_4.jpg"],
    ]
    iface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
        ],
        outputs=gr.Image(type="numpy", label="Detected Poses"),
        title="HockeyRink: A Model for Precise Ice Hockey Rink Keypoint Mapping and Analytics",
        description="Upload an image of ice hockey to detect keypoints on the rink.",
        examples=examples,
        theme=gr.themes.Base(),
    )
    return iface


if __name__ == "__main__":
    iface = create_interface()
    iface.launch()