Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,180 +3,120 @@ import pandas as pd
|
|
3 |
import joblib
|
4 |
from sklearn.ensemble import RandomForestRegressor
|
5 |
import plotly.express as px
|
|
|
6 |
|
7 |
# Mapping for position to numeric values
|
8 |
position_mapping = {
|
9 |
-
"PG": 1.0,
|
10 |
-
"SG": 2.0,
|
11 |
-
"SF": 3.0,
|
12 |
-
"PF": 4.0,
|
13 |
-
"C": 5.0,
|
14 |
}
|
15 |
|
16 |
# Predefined injury types
|
17 |
injury_types = [
|
18 |
-
"foot fracture injury",
|
19 |
-
"
|
20 |
-
"
|
21 |
-
"
|
22 |
-
"
|
23 |
-
"
|
24 |
-
"
|
25 |
-
"torn
|
26 |
-
"hip flexor strain injury",
|
27 |
-
"fractured leg injury",
|
28 |
-
"sprained mcl injury",
|
29 |
-
"ankle sprain injury",
|
30 |
-
"hamstring injury injury",
|
31 |
-
"meniscus tear injury",
|
32 |
-
"torn hamstring injury",
|
33 |
-
"dislocated shoulder injury",
|
34 |
-
"ankle fracture injury",
|
35 |
-
"fractured hand injury",
|
36 |
-
"bone spurs injury",
|
37 |
-
"acl tear injury",
|
38 |
-
"hip labrum injury",
|
39 |
-
"back surgery injury",
|
40 |
-
"arm injury injury",
|
41 |
-
"torn shoulder labrum injury",
|
42 |
-
"lower back spasm injury"
|
43 |
]
|
44 |
|
45 |
-
# Injury average days dictionary
|
46 |
average_days_injured = {
|
47 |
"foot fracture injury": 207.666667,
|
48 |
-
|
49 |
-
"calf strain injury": 236.000000,
|
50 |
-
"quad injury injury": 283.000000,
|
51 |
-
"shoulder sprain injury": 259.500000,
|
52 |
-
"foot sprain injury": 294.000000,
|
53 |
-
"torn rotator cuff injury injury": 251.500000,
|
54 |
-
"torn mcl injury": 271.000000,
|
55 |
-
"hip flexor strain injury": 253.000000,
|
56 |
-
"fractured leg injury": 250.250000,
|
57 |
-
"sprained mcl injury": 228.666667,
|
58 |
-
"ankle sprain injury": 231.333333,
|
59 |
-
"hamstring injury injury": 220.000000,
|
60 |
-
"meniscus tear injury": 201.250000,
|
61 |
-
"torn hamstring injury": 187.666667,
|
62 |
-
"dislocated shoulder injury": 269.000000,
|
63 |
-
"ankle fracture injury": 114.500000,
|
64 |
-
"fractured hand injury": 169.142857,
|
65 |
-
"bone spurs injury": 151.500000,
|
66 |
-
"acl tear injury": 268.000000,
|
67 |
-
"hip labrum injury": 247.500000,
|
68 |
-
"back surgery injury": 215.800000,
|
69 |
-
"arm injury injury": 303.666667,
|
70 |
-
"torn shoulder labrum injury": 195.666667,
|
71 |
"lower back spasm injury": 234.000000,
|
72 |
}
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
# Load player dataset
|
79 |
@st.cache_resource
|
80 |
def load_player_data():
|
81 |
return pd.read_csv("player_data.csv")
|
82 |
|
83 |
-
# Load Random Forest model
|
84 |
@st.cache_resource
|
85 |
def load_rf_model():
|
86 |
return joblib.load("rf_injury_change_model.pkl")
|
87 |
|
88 |
# Main Streamlit app
|
89 |
def main():
|
90 |
-
st.title("NBA Player Performance Predictor")
|
91 |
-
st.write(
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
)
|
97 |
-
|
98 |
-
# Load player data
|
99 |
player_data = load_player_data()
|
100 |
rf_model = load_rf_model()
|
101 |
|
102 |
-
# Sidebar
|
103 |
-
st.sidebar
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
player_name = st.sidebar.selectbox("Select Player", player_list)
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
if not player_row.empty:
|
114 |
position = player_row.iloc[0]['position']
|
115 |
position_numeric = position_mapping.get(position, 0)
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
# Default values for features
|
120 |
stats_columns = ['age', 'player_height', 'player_weight']
|
121 |
-
default_stats = {
|
122 |
-
stat: player_row.iloc[0][stat] if stat in player_row.columns else 0
|
123 |
-
for stat in stats_columns
|
124 |
-
}
|
125 |
|
126 |
-
# Allow manual adjustment of stats
|
127 |
for stat in default_stats.keys():
|
128 |
-
default_stats[stat] = st.
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
"Estimated Days Injured",
|
136 |
-
0,
|
137 |
-
365,
|
138 |
-
int(default_days_injured),
|
139 |
-
help=f"Default days for {injury_type}: {int(default_days_injured) if default_days_injured else 'N/A'}"
|
140 |
-
)
|
141 |
-
injury_occurrences = st.sidebar.number_input("Injury Occurrences", min_value=0, value=1)
|
142 |
-
|
143 |
-
# Prepare input data
|
144 |
input_data = pd.DataFrame([{
|
145 |
"days_injured": days_injured,
|
146 |
"injury_occurrences": injury_occurrences,
|
147 |
"position": position_numeric,
|
148 |
-
"injury_type": injury_type,
|
149 |
**default_stats
|
150 |
}])
|
151 |
-
|
152 |
-
# Encode injury type
|
153 |
input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
if __name__ == "__main__":
|
182 |
main()
|
|
|
3 |
import joblib
|
4 |
from sklearn.ensemble import RandomForestRegressor
|
5 |
import plotly.express as px
|
6 |
+
import plotly.graph_objects as go
|
7 |
|
8 |
# Mapping for position to numeric values
|
9 |
position_mapping = {
|
10 |
+
"PG": 1.0,
|
11 |
+
"SG": 2.0,
|
12 |
+
"SF": 3.0,
|
13 |
+
"PF": 4.0,
|
14 |
+
"C": 5.0,
|
15 |
}
|
16 |
|
17 |
# Predefined injury types
|
18 |
injury_types = [
|
19 |
+
"foot fracture injury", "hip flexor surgery injury", "calf strain injury",
|
20 |
+
"quad injury injury", "shoulder sprain injury", "foot sprain injury",
|
21 |
+
"torn rotator cuff injury injury", "torn mcl injury", "hip flexor strain injury",
|
22 |
+
"fractured leg injury", "sprained mcl injury", "ankle sprain injury",
|
23 |
+
"hamstring injury injury", "meniscus tear injury", "torn hamstring injury",
|
24 |
+
"dislocated shoulder injury", "ankle fracture injury", "fractured hand injury",
|
25 |
+
"bone spurs injury", "acl tear injury", "hip labrum injury", "back surgery injury",
|
26 |
+
"arm injury injury", "torn shoulder labrum injury", "lower back spasm injury"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
]
|
28 |
|
|
|
29 |
average_days_injured = {
|
30 |
"foot fracture injury": 207.666667,
|
31 |
+
# ... (other injuries with their averages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
"lower back spasm injury": 234.000000,
|
33 |
}
|
34 |
|
|
|
|
|
|
|
|
|
|
|
35 |
@st.cache_resource
|
36 |
def load_player_data():
|
37 |
return pd.read_csv("player_data.csv")
|
38 |
|
|
|
39 |
@st.cache_resource
|
40 |
def load_rf_model():
|
41 |
return joblib.load("rf_injury_change_model.pkl")
|
42 |
|
43 |
# Main Streamlit app
|
44 |
def main():
|
45 |
+
st.title("NBA Player Performance Predictor 🏀")
|
46 |
+
st.write("""
|
47 |
+
Use this tool to predict how a player's performance metrics might change
|
48 |
+
if they experience a hypothetical injury.
|
49 |
+
""")
|
50 |
+
|
|
|
|
|
|
|
51 |
player_data = load_player_data()
|
52 |
rf_model = load_rf_model()
|
53 |
|
54 |
+
# Layout: Sidebar for Inputs
|
55 |
+
with st.sidebar:
|
56 |
+
st.header("Player & Injury Inputs")
|
57 |
+
player_list = sorted(player_data['player_name'].dropna().unique())
|
58 |
+
player_name = st.selectbox("Select Player", player_list)
|
|
|
59 |
|
60 |
+
# Player Details
|
61 |
+
if player_name:
|
62 |
+
player_row = player_data[player_data['player_name'] == player_name]
|
|
|
|
|
63 |
position = player_row.iloc[0]['position']
|
64 |
position_numeric = position_mapping.get(position, 0)
|
65 |
+
st.write(f"**Position**: {position} (Numeric: {position_numeric})")
|
66 |
+
|
|
|
|
|
67 |
stats_columns = ['age', 'player_height', 'player_weight']
|
68 |
+
default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
|
|
|
|
|
|
|
69 |
|
|
|
70 |
for stat in default_stats.keys():
|
71 |
+
default_stats[stat] = st.number_input(f"{stat}", value=default_stats[stat])
|
72 |
+
|
73 |
+
injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
|
74 |
+
default_days_injured = average_days_injured.get(injury_type, 30)
|
75 |
+
days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
|
76 |
+
injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)
|
77 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
input_data = pd.DataFrame([{
|
79 |
"days_injured": days_injured,
|
80 |
"injury_occurrences": injury_occurrences,
|
81 |
"position": position_numeric,
|
82 |
+
"injury_type": injury_type,
|
83 |
**default_stats
|
84 |
}])
|
|
|
|
|
85 |
input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
|
86 |
|
87 |
+
# Main Layout
|
88 |
+
col1, col2 = st.columns(2)
|
89 |
+
|
90 |
+
with col1:
|
91 |
+
st.subheader("Player Details")
|
92 |
+
st.metric("Age", default_stats['age'])
|
93 |
+
st.metric("Height (cm)", default_stats['player_height'])
|
94 |
+
st.metric("Weight (kg)", default_stats['player_weight'])
|
95 |
+
st.image("player_placeholder.png", caption=f"{player_name}", width=200) # Placeholder image
|
96 |
+
|
97 |
+
with col2:
|
98 |
+
st.subheader("Prediction Results")
|
99 |
+
|
100 |
+
if st.button("Predict"):
|
101 |
+
predictions = rf_model.predict(input_data)
|
102 |
+
predictions = [round(float(pred), 2) for pred in predictions]
|
103 |
+
|
104 |
+
prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
|
105 |
+
result_df = pd.DataFrame([predictions], columns=prediction_columns)
|
106 |
+
|
107 |
+
# Bar Chart for Visualizing Predictions
|
108 |
+
fig = go.Figure(data=[
|
109 |
+
go.Bar(
|
110 |
+
x=prediction_columns,
|
111 |
+
y=predictions,
|
112 |
+
marker_color=["green" if val > 0 else "red" for val in predictions]
|
113 |
+
)
|
114 |
+
])
|
115 |
+
fig.update_layout(title="Predicted Performance Changes", xaxis_title="Metrics", yaxis_title="Change")
|
116 |
+
st.plotly_chart(fig)
|
117 |
+
|
118 |
+
# Display predictions as a table
|
119 |
+
st.table(result_df)
|
120 |
|
121 |
if __name__ == "__main__":
|
122 |
main()
|