File size: 3,353 Bytes
ed67c52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f55c607
ed67c52
f55c607
 
ed67c52
 
 
 
 
 
2c4555b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import onnxruntime as ort
import numpy as np
import pickle
import re

# Load the ONNX model once at import time; the resulting `ort_session` is the
# single session passed to recommend_formation_onnx() for every request.
# The path is relative, so it is resolved against the process's current
# working directory.
onnx_model_path: str = "formation_predictor.onnx"
ort_session = ort.InferenceSession(onnx_model_path)

def to_one_hot(indices, num_classes):
    """Return a one-hot float64 encoding of ``indices``.

    The result has shape ``np.shape(indices) + (num_classes,)``; for example a
    list of ``k`` indices yields a ``(k, num_classes)`` array.

    Args:
        indices: An integer class index or an array-like of indices.
        num_classes: Width of each one-hot row.

    Raises:
        IndexError: If any index lies outside ``[0, num_classes)``. Negative
            indices are rejected explicitly because NumPy indexing would
            otherwise wrap them around to a different class.
    """
    indices = np.asarray(indices, dtype=int)
    flat = indices.reshape(-1)
    # Guard on size: min()/max() raise on an empty array.
    if flat.size and (flat.min() < 0 or flat.max() >= num_classes):
        raise IndexError(f"class index out of range for num_classes={num_classes}")
    # Scatter ones into a zero matrix rather than slicing np.eye(num_classes),
    # which would allocate a full num_classes x num_classes identity per call.
    encoded = np.zeros((flat.size, num_classes))
    encoded[np.arange(flat.size), flat] = 1.0
    return encoded.reshape(indices.shape + (num_classes,))

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load encoder files that ship with this app; never pass a user-supplied
# path here.
def load_label_encoder(path="label_encoder.pkl"):
    """Load and return the pickled label encoder stored at ``path``.

    Args:
        path: File to unpickle. Defaults to ``"label_encoder.pkl"``, resolved
            relative to the current working directory.
    """
    with open(path, "rb") as f:
        return pickle.load(f)

# Module-level encoder shared by the handlers below: recommend() checks user
# input against `le.classes_`, and recommend_formation_onnx() scores every
# class index as a candidate formation.
le = load_label_encoder()
num_classes: int = len(le.classes_)  # width of each one-hot half of the model input

def prepare_input(opponent_formation, le, num_classes):
    """Return the one-hot encoding of the opponent formation, shape ``(1, num_classes)``.

    Args:
        opponent_formation: Either a formation label such as ``"4-4-2"`` or an
            already-encoded integer class index. Surrounding whitespace,
            quotes and square brackets are stripped from labels before lookup.
            Indices are used as-is.
        le: Fitted label encoder exposing ``transform``.
        num_classes: Number of classes known to ``le``.
    """
    if isinstance(opponent_formation, str):
        # Clean only strings: calling .strip() on an integer index would raise
        # AttributeError and make the index path unusable.
        cleaned = opponent_formation.strip().strip("'\"[]")
        opp_idx = le.transform([cleaned])[0]
    else:
        opp_idx = opponent_formation
    return to_one_hot([opp_idx], num_classes)

def recommend_formation_onnx(opponent_formation, ort_session, le, num_classes):
    """Score every known formation against the opponent and pick the best.

    For each candidate class index, the model input is the opponent's one-hot
    vector concatenated with the candidate's one-hot vector (float32, shape
    ``(1, 2 * num_classes)``). The score is element ``[0, 0]`` of the first
    model output.

    Returns:
        ``(best_formation, evaluated_formations)``. ``evaluated_formations``
        is a list of ``(formation, score)`` pairs sorted by descending score;
        ties keep class-index order. ``best_formation`` is the first
        highest-scoring candidate, and is ``None`` only when no score exceeds
        ``-inf`` (e.g. ``num_classes == 0``).
    """
    opp_one_hot = prepare_input(opponent_formation, le, num_classes)
    # Loop-invariant: the model's input name is the same for every run.
    input_name = ort_session.get_inputs()[0].name

    best_formation, best_score = None, -float("inf")
    evaluated_formations = []
    for our_idx in range(num_classes):
        our_one_hot = to_one_hot([our_idx], num_classes)
        input_vector = np.concatenate([opp_one_hot, our_one_hot], axis=1).astype(np.float32)

        ort_outs = ort_session.run(None, {input_name: input_vector})
        score = ort_outs[0][0, 0]

        formation = le.inverse_transform([our_idx])[0]
        evaluated_formations.append((formation, score))

        # Strict '>' keeps the earliest class index on ties.
        if score > best_score:
            best_score = score
            best_formation = formation

    # list.sort is stable even with reverse=True, so equal scores keep index order.
    evaluated_formations.sort(key=lambda x: x[1], reverse=True)
    return best_formation, evaluated_formations

def recommend(opponent_formation):
    """Gradio callback: validate the opponent formation and recommend a reply.

    Returns a ``(message, rows)`` pair matching the interface's two outputs.
    On invalid input, ``message`` starts with ``"Error:"`` and ``rows`` is
    empty. Otherwise ``rows`` holds ``(formation, score)`` pairs, best first.
    """
    # Treat a missing value as empty text so it reaches the format error below
    # instead of raising AttributeError.
    if opponent_formation is None:
        opponent_formation = ""
    # Strip whitespace, then quotes/brackets, then any whitespace they enclosed,
    # so inputs such as "[' 4-4-2 ']" normalise to "4-4-2".
    opponent_formation = opponent_formation.strip().strip("'\"[]").strip()

    # fullmatch (unlike match + '$') does not accept a trailing newline.
    if not re.fullmatch(r'\d+(-\d+)+', opponent_formation):
        return f"Error: Formation '{opponent_formation}' is not in the correct format (e.g., '3-4-2-1').", []

    if opponent_formation not in le.classes_:
        return f"Error: Formation '{opponent_formation}' not recognized.", []

    best_formation, evaluated_formations = recommend_formation_onnx(opponent_formation, ort_session, le, num_classes)
    return f"Recommended formation: {best_formation}", evaluated_formations

# Gradio UI: one free-text input for the opponent's formation, and two outputs
# that line up with recommend()'s (message, rows) return value.
_opponent_input = gr.Textbox(lines=1, placeholder="Enter opponent formation (e.g., '3-4-2-1')")
_message_output = gr.Textbox(label="Recommended Formation")
_scores_output = gr.Dataframe(headers=["Formation", "Score"], label="Evaluated Formations")

iface = gr.Interface(
    fn=recommend,
    inputs=_opponent_input,
    outputs=[_message_output, _scores_output],
    title="Deepfield Proyecto Maradona E3 Football Formation Recommender",
    description=(
        "Enter the opponent formation to get the recommended formation and "
        "a list of evaluated formations with their scores."
    ),
)

# Serve the interface as soon as this module runs; share=True additionally
# requests a public Gradio share link.
iface.launch(share=True)