|
import gradio as gr |
|
import onnxruntime as ort |
|
import numpy as np |
|
import pickle |
|
import re |
|
|
|
|
|
# Path of the exported ONNX model; relative, so it is resolved against the
# current working directory when the session is created.
onnx_model_path = "formation_predictor.onnx"

# Single inference session created at import time and reused for every request.
ort_session = ort.InferenceSession(onnx_model_path)
|
|
|
|
|
def to_one_hot(indices, num_classes):
    """Return the one-hot encoding of ``indices``.

    Args:
        indices: Integer class index or sequence of indices; converted to an
            integer NumPy array before use.
        num_classes: Size of the encoding, i.e. the length of each one-hot row.

    Returns:
        The rows of a ``num_classes`` x ``num_classes`` identity matrix
        selected by ``indices`` (one row per index).
    """
    row_selector = np.array(indices, dtype=int)
    identity = np.eye(num_classes)
    # Row i of the identity matrix is exactly the one-hot vector for class i.
    return identity[row_selector]
|
|
|
|
|
def load_label_encoder():
    """Load and return the object pickled in ``label_encoder.pkl``.

    The file is opened relative to the current working directory.

    NOTE: unpickling can execute arbitrary code, so ``label_encoder.pkl``
    must come from a trusted source.
    """
    with open("label_encoder.pkl", "rb") as encoder_file:
        return pickle.load(encoder_file)
|
|
|
# Label encoder loaded once at import time; it maps formation strings to
# integer class indices and back.
le = load_label_encoder()

# Number of classes known to the encoder; used as the one-hot vector length.
num_classes = len(le.classes_)
|
|
|
|
|
def prepare_input(opponent_formation, le, num_classes):
    """Encode the opponent formation as a one-hot row vector.

    Args:
        opponent_formation: Either a formation string (e.g. ``"4-4-2"``,
            optionally wrapped in whitespace, quotes or square brackets) or
            an already-encoded integer class index.
        le: Label encoder whose ``transform`` maps formation strings to
            class indices.
        num_classes: Length of the one-hot vector.

    Returns:
        The result of ``to_one_hot([index], num_classes)``, i.e. a single
        one-hot row for the opponent formation.

    Raises:
        Whatever ``le.transform`` raises for a formation string it does not
        know.
    """
    if isinstance(opponent_formation, str):
        # Only strings need normalising; calling ``.strip()`` unconditionally
        # would raise AttributeError for an integer class index and make the
        # pass-through branch below unreachable.
        cleaned = opponent_formation.strip().strip("'\"[]")
        opp_idx = le.transform([cleaned])[0]
    else:
        # Non-string input is treated as an already-encoded class index.
        opp_idx = opponent_formation

    return to_one_hot([opp_idx], num_classes)
|
|
|
|
|
def recommend_formation_onnx(opponent_formation, ort_session, le, num_classes):
    """Score every known formation against the opponent and pick the best.

    For each class index the model input is the opponent one-hot row
    concatenated with the candidate one-hot row (along axis 1), cast to
    float32. The score is element ``[0, 0]`` of the first model output.

    Args:
        opponent_formation: Opponent formation, passed to ``prepare_input``.
        ort_session: Inference session providing ``get_inputs()`` and
            ``run()``.
        le: Label encoder used to encode the opponent and to map candidate
            indices back to formation names via ``inverse_transform``.
        num_classes: Number of candidate formations / one-hot length.

    Returns:
        ``(best_formation, evaluated_formations)`` where
        ``evaluated_formations`` is a list of ``(formation, score)`` tuples
        sorted by score in descending order and ``best_formation`` is the
        first candidate with the strictly highest score (``None`` when no
        candidate scores above negative infinity, e.g. ``num_classes == 0``).
    """
    opp_one_hot = prepare_input(opponent_formation, le, num_classes)

    # The input tensor name does not change between runs, so look it up once
    # instead of on every loop iteration.
    input_name = ort_session.get_inputs()[0].name

    best_formation, best_score = None, -float("inf")
    evaluated_formations = []

    for our_idx in range(num_classes):
        our_one_hot = to_one_hot([our_idx], num_classes)
        input_vector = np.concatenate(
            [opp_one_hot, our_one_hot], axis=1
        ).astype(np.float32)

        ort_outs = ort_session.run(None, {input_name: input_vector})
        score = ort_outs[0][0, 0]

        formation = le.inverse_transform([our_idx])[0]
        evaluated_formations.append((formation, score))

        # Strict comparison: on ties the earliest candidate wins.
        if score > best_score:
            best_score = score
            best_formation = formation

    evaluated_formations.sort(key=lambda x: x[1], reverse=True)
    return best_formation, evaluated_formations
|
|
|
|
|
def recommend(opponent_formation):
    """Validate the user's input and return the recommendation for the UI.

    Args:
        opponent_formation: Raw text entered by the user, e.g. ``"3-4-2-1"``.
            Surrounding whitespace, quotes and square brackets are ignored.
            ``None`` is treated like an empty string.

    Returns:
        A ``(message, rows)`` tuple. On success ``message`` names the
        recommended formation and ``rows`` is the list of
        ``(formation, score)`` tuples from ``recommend_formation_onnx``.
        On invalid input ``message`` is an error text and ``rows`` is ``[]``.
    """
    # Guard against a missing value so validation reports a format error
    # instead of crashing on ``None.strip()``.
    raw_text = "" if opponent_formation is None else opponent_formation

    # The final strip removes whitespace that was enclosed by the quotes or
    # brackets (e.g. "'4-4-2 '").
    opponent_formation = raw_text.strip().strip("'\"[]").strip()

    # fullmatch requires the whole string to be digit groups joined by '-'.
    if not re.fullmatch(r"\d+(-\d+)+", opponent_formation):
        return f"Error: Formation '{opponent_formation}' is not in the correct format (e.g., '3-4-2-1').", []

    # Well-formed but not one of the classes the label encoder knows.
    if opponent_formation not in le.classes_:
        return f"Error: Formation '{opponent_formation}' not recognized.", []

    best_formation, evaluated_formations = recommend_formation_onnx(
        opponent_formation, ort_session, le, num_classes
    )
    return f"Recommended formation: {best_formation}", evaluated_formations
|
|
|
|
|
# Gradio UI definition: one text input wired to `recommend`, whose two return
# values fill the text box and the table below.
iface = gr.Interface(
    title="Deepfield Proyecto Maradona E3 Football Formation Recommender",
    description="Enter the opponent formation to get the recommended formation and a list of evaluated formations with their scores.",
    fn=recommend,
    # Single-line free-text field for the opponent formation.
    inputs=gr.Textbox(lines=1, placeholder="Enter opponent formation (e.g., '3-4-2-1')"),
    outputs=[
        # First return value of `recommend`: result or error message.
        gr.Textbox(label="Recommended Formation"),
        # Second return value of `recommend`: (formation, score) rows.
        gr.Dataframe(headers=["Formation", "Score"], label="Evaluated Formations"),
    ],
)
|
|
|
|
|
# Start the web server only when this file is executed as a script, so that
# importing the module (for reuse or testing) does not launch the app.
if __name__ == "__main__":
    # share=True requests a publicly shareable link from Gradio.
    iface.launch(share=True)