Gordon-H commited on
Commit
ed67c52
·
verified ·
1 Parent(s): a69ca74

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +84 -0
  2. formation_predictor.onnx +3 -0
  3. label_encoder.pkl +3 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ import pickle
5
+ import re
6
+
7
+ # Load the ONNX model
8
+ onnx_model_path = "formation_predictor.onnx"
9
+ ort_session = ort.InferenceSession(onnx_model_path)
10
+
11
+ # Function to convert input data to one-hot encoding
12
+ def to_one_hot(indices, num_classes):
13
+ indices = np.array(indices, dtype=int)
14
+ return np.eye(num_classes)[indices]
15
+
16
+ # Load the label encoder
17
+ def load_label_encoder():
18
+ with open("label_encoder.pkl", "rb") as f:
19
+ le = pickle.load(f)
20
+ return le
21
+
22
+ le = load_label_encoder()
23
+ num_classes = len(le.classes_)
24
+
25
+ # Function to prepare input data
26
+ def prepare_input(opponent_formation, le, num_classes):
27
+ opponent_formation = opponent_formation.strip().strip("'\"[]") # Ensure no leading/trailing spaces, quotes, or brackets
28
+ opp_idx = le.transform([opponent_formation])[0] if isinstance(opponent_formation, str) else opponent_formation
29
+ opp_one_hot = to_one_hot([opp_idx], num_classes)
30
+ return opp_one_hot
31
+
32
+ # Function to recommend formation using ONNX model
33
+ def recommend_formation_onnx(opponent_formation, ort_session, le, num_classes):
34
+ opp_one_hot = prepare_input(opponent_formation, le, num_classes)
35
+
36
+ best_formation, best_score = None, -float("inf")
37
+ evaluated_formations = []
38
+ for our_idx in range(num_classes):
39
+ our_one_hot = to_one_hot([our_idx], num_classes)
40
+ input_vector = np.concatenate([opp_one_hot, our_one_hot], axis=1).astype(np.float32)
41
+
42
+ # Run the ONNX model
43
+ ort_inputs = {ort_session.get_inputs()[0].name: input_vector}
44
+ ort_outs = ort_session.run(None, ort_inputs)
45
+ score = ort_outs[0][0, 0]
46
+
47
+ formation = le.inverse_transform([our_idx])[0]
48
+ evaluated_formations.append((formation, score))
49
+
50
+ if score > best_score:
51
+ best_score = score
52
+ best_formation = formation
53
+
54
+ evaluated_formations.sort(key=lambda x: x[1], reverse=True)
55
+ return best_formation, evaluated_formations
56
+
57
+ # Function to handle the recommend button click
58
+ def recommend(opponent_formation):
59
+ opponent_formation = opponent_formation.strip().strip("'\"[]") # Ensure no leading/trailing spaces, quotes, or brackets
60
+
61
+ # Validate the format of the opponent formation
62
+ if not re.match(r'^\d+(-\d+)+$', opponent_formation):
63
+ return f"Error: Formation '{opponent_formation}' is not in the correct format (e.g., '3-4-2-1').", []
64
+
65
+ if opponent_formation not in le.classes_:
66
+ return f"Error: Formation '{opponent_formation}' not recognized.", []
67
+
68
+ best_formation, evaluated_formations = recommend_formation_onnx(opponent_formation, ort_session, le, num_classes)
69
+ return f"Recommended formation: {best_formation}", evaluated_formations
70
+
71
+ # Create the Gradio interface
72
+ iface = gr.Interface(
73
+ fn=recommend,
74
+ inputs=gr.inputs.Textbox(lines=1, placeholder="Enter opponent formation (e.g., '3-4-2-1')"),
75
+ outputs=[
76
+ gr.outputs.Textbox(label="Recommended Formation"),
77
+ gr.outputs.Dataframe(headers=["Formation", "Score"], label="Evaluated Formations")
78
+ ],
79
+ title="Deepfield Proyecto Maradona E3 Football Formation Recommender",
80
+ description="Enter the opponent formation to get the recommended formation and a list of evaluated formations with their scores."
81
+ )
82
+
83
+ # Launch the Gradio interface
84
+ iface.launch()
formation_predictor.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1de5260c6c8457b12da0c2930bc5565b55099bdad0831cab031a690bd8244d94
3
+ size 18124
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c26c641ba961031566abddac21718a0bd1e01e1e1e31022a08a81b594c7095fd
3
+ size 402
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ onnxruntime
3
+ numpy
4
+ scikit-learn