Upload 4 files
Browse files- app.py +84 -0
- formation_predictor.onnx +3 -0
- label_encoder.pkl +3 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import onnxruntime as ort
|
3 |
+
import numpy as np
|
4 |
+
import pickle
|
5 |
+
import re
|
6 |
+
|
7 |
+
# Load the ONNX model
|
8 |
+
onnx_model_path = "formation_predictor.onnx"
|
9 |
+
ort_session = ort.InferenceSession(onnx_model_path)
|
10 |
+
|
11 |
+
# Function to convert input data to one-hot encoding
|
12 |
+
def to_one_hot(indices, num_classes):
|
13 |
+
indices = np.array(indices, dtype=int)
|
14 |
+
return np.eye(num_classes)[indices]
|
15 |
+
|
16 |
+
# Load the label encoder
|
17 |
+
def load_label_encoder():
|
18 |
+
with open("label_encoder.pkl", "rb") as f:
|
19 |
+
le = pickle.load(f)
|
20 |
+
return le
|
21 |
+
|
22 |
+
le = load_label_encoder()
|
23 |
+
num_classes = len(le.classes_)
|
24 |
+
|
25 |
+
# Function to prepare input data
|
26 |
+
def prepare_input(opponent_formation, le, num_classes):
|
27 |
+
opponent_formation = opponent_formation.strip().strip("'\"[]") # Ensure no leading/trailing spaces, quotes, or brackets
|
28 |
+
opp_idx = le.transform([opponent_formation])[0] if isinstance(opponent_formation, str) else opponent_formation
|
29 |
+
opp_one_hot = to_one_hot([opp_idx], num_classes)
|
30 |
+
return opp_one_hot
|
31 |
+
|
32 |
+
# Function to recommend formation using ONNX model
|
33 |
+
def recommend_formation_onnx(opponent_formation, ort_session, le, num_classes):
|
34 |
+
opp_one_hot = prepare_input(opponent_formation, le, num_classes)
|
35 |
+
|
36 |
+
best_formation, best_score = None, -float("inf")
|
37 |
+
evaluated_formations = []
|
38 |
+
for our_idx in range(num_classes):
|
39 |
+
our_one_hot = to_one_hot([our_idx], num_classes)
|
40 |
+
input_vector = np.concatenate([opp_one_hot, our_one_hot], axis=1).astype(np.float32)
|
41 |
+
|
42 |
+
# Run the ONNX model
|
43 |
+
ort_inputs = {ort_session.get_inputs()[0].name: input_vector}
|
44 |
+
ort_outs = ort_session.run(None, ort_inputs)
|
45 |
+
score = ort_outs[0][0, 0]
|
46 |
+
|
47 |
+
formation = le.inverse_transform([our_idx])[0]
|
48 |
+
evaluated_formations.append((formation, score))
|
49 |
+
|
50 |
+
if score > best_score:
|
51 |
+
best_score = score
|
52 |
+
best_formation = formation
|
53 |
+
|
54 |
+
evaluated_formations.sort(key=lambda x: x[1], reverse=True)
|
55 |
+
return best_formation, evaluated_formations
|
56 |
+
|
57 |
+
# Function to handle the recommend button click
|
58 |
+
def recommend(opponent_formation):
|
59 |
+
opponent_formation = opponent_formation.strip().strip("'\"[]") # Ensure no leading/trailing spaces, quotes, or brackets
|
60 |
+
|
61 |
+
# Validate the format of the opponent formation
|
62 |
+
if not re.match(r'^\d+(-\d+)+$', opponent_formation):
|
63 |
+
return f"Error: Formation '{opponent_formation}' is not in the correct format (e.g., '3-4-2-1').", []
|
64 |
+
|
65 |
+
if opponent_formation not in le.classes_:
|
66 |
+
return f"Error: Formation '{opponent_formation}' not recognized.", []
|
67 |
+
|
68 |
+
best_formation, evaluated_formations = recommend_formation_onnx(opponent_formation, ort_session, le, num_classes)
|
69 |
+
return f"Recommended formation: {best_formation}", evaluated_formations
|
70 |
+
|
71 |
+
# Create the Gradio interface
|
72 |
+
iface = gr.Interface(
|
73 |
+
fn=recommend,
|
74 |
+
inputs=gr.inputs.Textbox(lines=1, placeholder="Enter opponent formation (e.g., '3-4-2-1')"),
|
75 |
+
outputs=[
|
76 |
+
gr.outputs.Textbox(label="Recommended Formation"),
|
77 |
+
gr.outputs.Dataframe(headers=["Formation", "Score"], label="Evaluated Formations")
|
78 |
+
],
|
79 |
+
title="Deepfield Proyecto Maradona E3 Football Formation Recommender",
|
80 |
+
description="Enter the opponent formation to get the recommended formation and a list of evaluated formations with their scores."
|
81 |
+
)
|
82 |
+
|
83 |
+
# Launch the Gradio interface
|
84 |
+
iface.launch()
|
formation_predictor.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1de5260c6c8457b12da0c2930bc5565b55099bdad0831cab031a690bd8244d94
|
3 |
+
size 18124
|
label_encoder.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c26c641ba961031566abddac21718a0bd1e01e1e1e31022a08a81b594c7095fd
|
3 |
+
size 402
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
onnxruntime
|
3 |
+
numpy
|
4 |
+
scikit-learn
|