# app.py
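"""Gradio demo: detect hockey players with a fine-tuned YOLO model, then
classify each player's movement direction (8 compass-style classes) with a
SqueezeNet head and draw the result as an arrow on the image."""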
import gradio as gr
import torch
from torchvision import transforms, models
import cv2
from functools import lru_cache
from PIL import Image
from ultralytics import YOLO
@lru_cache(maxsize=1)
def load_models():
    """Load both models once; lru_cache avoids reloading them on every request."""
    # Initialize YOLO with the fine-tuned hockey weights
    yolo_model = YOLO('HockeyAI.pt')

    # Initialize SqueezeNet and swap the final 1x1 conv so the head predicts
    # the 8 direction classes instead of the 1000 ImageNet classes
    squeezenet_model = models.squeezenet1_1(weights=None)
    squeezenet_model.classifier[1] = torch.nn.Conv2d(512, 8, kernel_size=1)
    squeezenet_model.num_classes = 8
    squeezenet_model.load_state_dict(
        torch.load('best_model_squezenet.pth', map_location=torch.device('cpu'))
    )
    squeezenet_model.eval()

    return yolo_model, squeezenet_model
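
# A quick sanity check (a sketch; assumes the weight files are present next to
# this script):
#
#   _, squeezenet = load_models()
#   with torch.no_grad():
#       logits = squeezenet(torch.zeros(1, 3, 224, 224))
#   print(logits.shape)  # torch.Size([1, 8]) -- one logit per direction class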
def process_image(input_image):
    if input_image is None:
        return None

    # Accept either a file path or a numpy array; work in RGB throughout
    if isinstance(input_image, str):
        image = cv2.imread(input_image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        image = input_image.copy()

    # Models are cached, so this is only expensive on the first call
    yolo_model, squeezenet_model = load_models()
    # Direction labels, indexed by the classifier's output class
    class_labels = [
        "Bottom", "Bottom_Left", "Bottom_Right", "Left",
        "Right", "Top", "Top_Left", "Top_Right"
    ]

    # SqueezeNet preprocessing: 224x224 input normalized with ImageNet stats
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Run YOLO detection
    results = yolo_model(image)

    # Process each detection
    for box in results[0].boxes:
        xyxy = box.xyxy[0].cpu().numpy()
        conf = float(box.conf[0].cpu().numpy())
        cls = int(box.cls[0].cpu().numpy())

        # Process only players (class 4 in this model) above the confidence threshold
        if cls == 4 and conf > 0.5:
            x1, y1, x2, y2 = map(int, xyxy)

            # Crop the detected player for direction classification
            if x2 > x1 and y2 > y1:
                cropped_array = image[y1:y2, x1:x2]
                if cropped_array.size > 0:
                    cropped_image = Image.fromarray(cropped_array)

                    # Predict the movement direction
                    image_tensor = transform(cropped_image).unsqueeze(0)
                    with torch.no_grad():
                        output = squeezenet_model(image_tensor)
                        direction_class = torch.argmax(output, dim=1).item()
                    # Draw the bounding box and confidence score
                    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(image, f"{conf:.2f}", (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

                    # Draw the direction arrow from the center of the box
                    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
                    arrow_length = 80
                    direction = class_labels[direction_class]

                    # Split the label into vertical and horizontal components,
                    # so e.g. "Top_Left" offsets the endpoint in both axes
                    end_x, end_y = center_x, center_y
                    if "Top" in direction:
                        end_y = center_y - arrow_length
                    elif "Bottom" in direction:
                        end_y = center_y + arrow_length
                    if "Left" in direction:
                        end_x = center_x - arrow_length
                    elif "Right" in direction:
                        end_x = center_x + arrow_length

                    cv2.arrowedLine(image, (center_x, center_y), (end_x, end_y),
                                    (255, 0, 0), 4, tipLength=0.4)

    return image
# Create the Gradio interface
def gradio_interface():
    with gr.Blocks() as iface:
        gr.Markdown("# Player Direction Detection")
        gr.Markdown("Upload an image to detect players and their movement directions")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="numpy")
            with gr.Column():
                output_image = gr.Image(label="Output Image")

        # Re-run the pipeline whenever the input image changes
        input_image.change(
            fn=process_image,
            inputs=[input_image],
            outputs=[output_image]
        )

        # Cached example images (the files must sit next to app.py)
        gr.Examples(
            examples=["example-1.jpg", "example-2.jpg"],
            inputs=input_image,
            outputs=output_image,
            fn=process_image,
            cache_examples=True
        )

    return iface
if __name__ == "__main__":
    iface = gradio_interface()
    iface.launch()
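
# A minimal standalone sketch (assumes the weights and an example image are
# available locally): run the pipeline without the UI and save the result.
# cv2.imwrite expects BGR, so convert back from the RGB array we work with.
#
#   annotated = process_image("example-1.jpg")
#   if annotated is not None:
#       cv2.imwrite("annotated.jpg", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))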