"""Gradio app that recommends a football formation against a given opponent.

For every formation known to the label encoder, the ONNX model is run on the
(opponent, candidate) one-hot pair and the highest-scoring candidate is
recommended; all candidates are also returned sorted by score.
"""

import pickle
import re

import gradio as gr
import numpy as np
import onnxruntime as ort

# Load the ONNX model
onnx_model_path = "formation_predictor.onnx"
ort_session = ort.InferenceSession(onnx_model_path)

# Accepted formation syntax: dash-separated integers, e.g. "4-4-2", "3-4-2-1".
_FORMATION_PATTERN = re.compile(r"\d+(-\d+)+")


def _clean_formation(text):
    """Strip surrounding whitespace, then any surrounding quotes or brackets."""
    # Users often paste values such as "'4-4-2'" or "[4-3-3]" copied from lists.
    return text.strip().strip("'\"[]")


def to_one_hot(indices, num_classes):
    """Return a (len(indices), num_classes) one-hot matrix for integer indices.

    Indices >= num_classes raise IndexError (numpy fancy indexing).
    """
    indices = np.array(indices, dtype=int)
    return np.eye(num_classes)[indices]


def load_label_encoder():
    """Load the fitted label encoder mapping formation strings <-> indices.

    Presumably a scikit-learn LabelEncoder (this file uses its `classes_`,
    `transform` and `inverse_transform`) -- confirm against the training code.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only ship a trusted
    # label_encoder.pkl with the app; never load a user-supplied file here.
    with open("label_encoder.pkl", "rb") as f:
        return pickle.load(f)


le = load_label_encoder()
num_classes = len(le.classes_)


def prepare_input(opponent_formation, le, num_classes):
    """One-hot encode the opponent formation as a (1, num_classes) matrix.

    `opponent_formation` may be a formation string (cleaned, then encoded via
    `le.transform`, which raises for unknown labels) or an already-encoded
    integer index.
    """
    # Check the type BEFORE calling string methods, so integer indices work.
    if isinstance(opponent_formation, str):
        opp_idx = le.transform([_clean_formation(opponent_formation)])[0]
    else:
        opp_idx = opponent_formation
    return to_one_hot([opp_idx], num_classes)


def recommend_formation_onnx(opponent_formation, ort_session, le, num_classes):
    """Score every candidate formation against the opponent with the model.

    Returns:
        (best_formation, evaluated_formations): the highest-scoring formation
        (first one wins on ties; None if there are no candidates) and a list
        of (formation, score) tuples sorted by score, highest first.
    """
    opp_one_hot = prepare_input(opponent_formation, le, num_classes)

    # Loop invariants: the model's input name and the candidate labels.
    input_name = ort_session.get_inputs()[0].name
    candidates = le.inverse_transform(np.arange(num_classes))

    best_formation, best_score = None, -float("inf")
    evaluated_formations = []
    for our_idx, formation in enumerate(candidates):
        our_one_hot = to_one_hot([our_idx], num_classes)
        # Model input is [opponent one-hot | candidate one-hot], shape (1, 2*num_classes).
        input_vector = np.concatenate([opp_one_hot, our_one_hot], axis=1).astype(np.float32)
        ort_outs = ort_session.run(None, {input_name: input_vector})
        # NOTE(review): assumes element [0, 0] of the first model output is the
        # candidate's score -- confirm against the exported model's outputs.
        score = ort_outs[0][0, 0]
        evaluated_formations.append((formation, score))
        if score > best_score:
            best_score = score
            best_formation = formation

    evaluated_formations.sort(key=lambda item: item[1], reverse=True)
    return best_formation, evaluated_formations


def recommend(opponent_formation):
    """Gradio callback: validate the input and return (message, table rows)."""
    # Treat a missing value (None) like an empty textbox instead of crashing.
    opponent_formation = _clean_formation(opponent_formation or "")

    # Validate the format of the opponent formation
    if not _FORMATION_PATTERN.fullmatch(opponent_formation):
        return (
            f"Error: Formation '{opponent_formation}' is not in the correct format (e.g., '3-4-2-1').",
            [],
        )
    if opponent_formation not in le.classes_:
        return f"Error: Formation '{opponent_formation}' not recognized.", []

    best_formation, evaluated_formations = recommend_formation_onnx(
        opponent_formation, ort_session, le, num_classes
    )
    return f"Recommended formation: {best_formation}", evaluated_formations


# Create the Gradio interface
iface = gr.Interface(
    fn=recommend,
    inputs=gr.Textbox(lines=1, placeholder="Enter opponent formation (e.g., '3-4-2-1')"),
    outputs=[
        gr.Textbox(label="Recommended Formation"),
        gr.Dataframe(headers=["Formation", "Score"], label="Evaluated Formations"),
    ],
    title="Deepfield Proyecto Maradona E3 Football Formation Recommender",
    description="Enter the opponent formation to get the recommended formation and a list of evaluated formations with their scores.",
)

# Launch only when run as a script, not when imported (e.g. by tests).
if __name__ == "__main__":
    iface.launch(share=True)