Gordon-H's picture
update public link
2c4555b verified
import gradio as gr
import onnxruntime as ort
import numpy as np
import pickle
import re
# Load the ONNX model once at import time; every request reuses this session.
onnx_model_path = "formation_predictor.onnx"
# Pin the execution provider explicitly: onnxruntime >= 1.9 raises ValueError
# when `providers` is omitted on builds that ship additional (e.g. CUDA)
# providers. CPU is what a CPU-only build would select by default anyway.
ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
def to_one_hot(indices, num_classes):
    """Return one-hot rows for the given class indices.

    Each index selects the matching row of a ``num_classes`` x ``num_classes``
    float64 identity matrix. For a 1-D sequence of indices, the result therefore
    has shape ``(len(indices), num_classes)``. Indices are cast to ``int``
    first. Standard NumPy indexing rules apply: negative indices wrap and
    out-of-range indices raise ``IndexError``.
    """
    rows = np.asarray(indices, dtype=int)
    identity = np.eye(num_classes)
    return identity[rows]
def load_label_encoder(path="label_encoder.pkl"):
    """Load the pickled label encoder that maps formation strings to class ids.

    Args:
        path: Location of the pickle file. The default, ``"label_encoder.pkl"``
            relative to the working directory, matches the original behaviour.

    Returns:
        The unpickled object. The rest of this module relies on its
        ``classes_``, ``transform`` and ``inverse_transform`` attributes, so it
        is presumably a fitted scikit-learn ``LabelEncoder`` (confirm against
        the training script).

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        point *path* at an artifact shipped with this app, never at
        user-supplied data.
    """
    with open(path, "rb") as f:
        return pickle.load(f)
# Label encoder shared by all requests: `transform` maps formation strings to
# integer class ids, `inverse_transform` maps ids back to strings.
le = load_label_encoder()
# Number of known formations, i.e. the width of each one-hot vector. The model
# input is two such vectors (opponent, candidate) concatenated side by side.
num_classes = len(le.classes_)
def prepare_input(opponent_formation, le, num_classes):
    """Encode the opponent formation as a single one-hot row.

    Args:
        opponent_formation: Either a formation string such as ``"3-4-2-1"`` or
            an already-encoded integer class index. Strings have surrounding
            whitespace removed first, then surrounding quotes and square
            brackets, before the label lookup.
        le: Fitted label encoder whose ``transform`` maps a formation string to
            its integer class index.
        num_classes: Width of the one-hot vector.

    Returns:
        ``to_one_hot([index], num_classes)``, i.e. a ``(1, num_classes)`` array.

    Raises:
        Whatever ``le.transform`` raises for an unknown formation string
        (``ValueError`` for scikit-learn's ``LabelEncoder``, presumably the
        encoder used here).
    """
    if isinstance(opponent_formation, str):
        # Only strings need cleaning and lookup. Calling .strip() on a non-string
        # would raise AttributeError, so encoded indices must bypass this branch.
        cleaned = opponent_formation.strip().strip("'\"[]")
        opp_idx = le.transform([cleaned])[0]
    else:
        opp_idx = opponent_formation
    return to_one_hot([opp_idx], num_classes)
def recommend_formation_onnx(opponent_formation, ort_session, le, num_classes):
    """Score every candidate formation against the opponent and pick the best.

    For each class index, the model receives ``[opponent one-hot | candidate
    one-hot]`` as a ``(1, 2 * num_classes)`` float32 array. Element ``[0, 0]``
    of its first output is taken as the candidate's score; this presumably
    assumes a ``(1, 1)`` regression output (confirm against the exported model).

    Args:
        opponent_formation: Formation string or encoded index (see
            ``prepare_input``).
        ort_session: onnxruntime ``InferenceSession`` whose first input is
            fed the concatenated vector.
        le: Fitted label encoder used to decode class indices to names.
        num_classes: Number of candidate formations to evaluate.

    Returns:
        ``(best_formation, evaluated_formations)``.
        ``best_formation`` is the first formation, in class-index order, with
        the strictly highest score. It is ``None`` if no score exceeds ``-inf``,
        e.g. when there are no classes.
        ``evaluated_formations`` is a list of ``(formation, score)`` tuples
        sorted by descending score. The sort is stable, so ties keep
        class-index order.
    """
    opp_one_hot = prepare_input(opponent_formation, le, num_classes)

    # Loop invariants: the model's input name and the decoded label of every
    # class index do not depend on the candidate, so compute them once.
    input_name = ort_session.get_inputs()[0].name
    formations = le.inverse_transform(np.arange(num_classes)) if num_classes else []

    best_formation, best_score = None, -float("inf")
    evaluated_formations = []
    for our_idx, formation in enumerate(formations):
        our_one_hot = to_one_hot([our_idx], num_classes)
        input_vector = np.concatenate([opp_one_hot, our_one_hot], axis=1).astype(np.float32)
        ort_outs = ort_session.run(None, {input_name: input_vector})
        score = ort_outs[0][0, 0]
        evaluated_formations.append((formation, score))
        # Strict comparison: on ties the earliest class index wins.
        if score > best_score:
            best_score = score
            best_formation = formation

    evaluated_formations.sort(key=lambda x: x[1], reverse=True)
    return best_formation, evaluated_formations
def recommend(opponent_formation):
    """Gradio callback: validate the opponent formation and recommend a reply.

    Args:
        opponent_formation: Raw textbox contents. ``None`` is treated as an
            empty string. Any mix of surrounding whitespace, quotes and square
            brackets (e.g. ``"[ '3-4-2-1' ]"``) is removed before validation.

    Returns:
        ``(message, rows)``. ``message`` is shown in the result textbox.
        ``rows`` is the list of ``(formation, score)`` tuples for the
        dataframe, or ``[]`` when the input is rejected.
    """
    text = "" if opponent_formation is None else opponent_formation
    # Strip whitespace, quotes and brackets together, so that nested wrappers
    # like "[ '3-4-2-1' ]" are removed completely. None of these characters can
    # occur in a valid formation.
    opponent_formation = re.sub(r"^[\s'\"\[\]]+|[\s'\"\[\]]+$", "", text)

    # Validate the format: digit groups joined by single dashes, e.g. 4-4-2.
    if not re.fullmatch(r"\d+(-\d+)+", opponent_formation):
        return f"Error: Formation '{opponent_formation}' is not in the correct format (e.g., '3-4-2-1').", []
    if opponent_formation not in le.classes_:
        return f"Error: Formation '{opponent_formation}' not recognized.", []

    best_formation, evaluated_formations = recommend_formation_onnx(opponent_formation, ort_session, le, num_classes)
    return f"Recommended formation: {best_formation}", evaluated_formations
# Build the Gradio UI. One textbox goes in; ``recommend`` produces a message
# textbox and a (formation, score) table.
iface = gr.Interface(
    fn=recommend,
    inputs=gr.Textbox(lines=1, placeholder="Enter opponent formation (e.g., '3-4-2-1')"),
    outputs=[
        gr.Textbox(label="Recommended Formation"),
        gr.Dataframe(headers=["Formation", "Score"], label="Evaluated Formations"),
    ],
    title="Deepfield Proyecto Maradona E3 Football Formation Recommender",
    description="Enter the opponent formation to get the recommended formation and a list of evaluated formations with their scores.",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module for testing or reuse does not start a server. share=True asks Gradio
# for a public share link.
if __name__ == "__main__":
    iface.launch(share=True)