Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,8 +2,16 @@ import streamlit as st
|
|
2 |
import pandas as pd
|
3 |
import joblib
|
4 |
from sklearn.ensemble import RandomForestRegressor
|
5 |
-
import plotly.express as px
|
6 |
import plotly.graph_objects as go
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# Mapping for position to numeric values
|
9 |
position_mapping = {
|
@@ -14,7 +22,7 @@ position_mapping = {
|
|
14 |
"C": 5.0,
|
15 |
}
|
16 |
|
17 |
-
#
|
18 |
injury_types = [
|
19 |
"foot fracture injury", "hip flexor surgery injury", "calf strain injury",
|
20 |
"quad injury injury", "shoulder sprain injury", "foot sprain injury",
|
@@ -28,10 +36,66 @@ injury_types = [
|
|
28 |
|
29 |
average_days_injured = {
|
30 |
"foot fracture injury": 207.666667,
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
"lower back spasm injury": 234.000000,
|
33 |
}
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
@st.cache_resource
|
36 |
def load_player_data():
|
37 |
return pd.read_csv("player_data.csv")
|
@@ -40,83 +104,151 @@ def load_player_data():
|
|
40 |
def load_rf_model():
|
41 |
return joblib.load("rf_injury_change_model.pkl")
|
42 |
|
43 |
-
# Main
|
|
|
44 |
def main():
|
45 |
st.title("NBA Player Performance Predictor 🏀")
|
46 |
-
st.
|
47 |
Use this tool to predict how a player's performance metrics might change
|
48 |
if they experience a hypothetical injury.
|
49 |
""")
|
50 |
|
|
|
51 |
player_data = load_player_data()
|
52 |
rf_model = load_rf_model()
|
53 |
|
54 |
-
#
|
55 |
with st.sidebar:
|
56 |
st.header("Player & Injury Inputs")
|
57 |
player_list = sorted(player_data['player_name'].dropna().unique())
|
58 |
player_name = st.selectbox("Select Player", player_list)
|
59 |
|
60 |
-
# Player Details
|
61 |
if player_name:
|
|
|
62 |
player_row = player_data[player_data['player_name'] == player_name]
|
|
|
63 |
position = player_row.iloc[0]['position']
|
64 |
-
position_numeric = position_mapping.get(position, 0)
|
65 |
-
st.write(f"**Position**: {position} (Numeric: {position_numeric})")
|
66 |
-
|
67 |
stats_columns = ['age', 'player_height', 'player_weight']
|
68 |
-
default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
|
69 |
|
|
|
|
|
|
|
|
|
|
|
70 |
for stat in default_stats.keys():
|
71 |
default_stats[stat] = st.number_input(f"{stat}", value=default_stats[stat])
|
72 |
|
|
|
73 |
injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
|
74 |
default_days_injured = average_days_injured.get(injury_type, 30)
|
75 |
days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
|
76 |
injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)
|
77 |
|
|
|
78 |
input_data = pd.DataFrame([{
|
79 |
"days_injured": days_injured,
|
80 |
"injury_occurrences": injury_occurrences,
|
81 |
-
"position":
|
82 |
"injury_type": injury_type,
|
83 |
**default_stats
|
84 |
}])
|
85 |
input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
|
86 |
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
with col1:
|
91 |
st.subheader("Player Details")
|
92 |
st.metric("Age", default_stats['age'])
|
93 |
st.metric("Height (cm)", default_stats['player_height'])
|
94 |
st.metric("Weight (kg)", default_stats['player_weight'])
|
95 |
-
st.image("player_placeholder.png", caption=f"{player_name}", width=200) # Placeholder image
|
96 |
|
97 |
with col2:
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
if __name__ == "__main__":
|
122 |
main()
|
|
|
2 |
import pandas as pd
|
3 |
import joblib
|
4 |
from sklearn.ensemble import RandomForestRegressor
|
|
|
5 |
import plotly.graph_objects as go
|
6 |
+
from PIL import Image
|
7 |
+
import plotly.express as px
|
8 |
+
|
9 |
+
# Set the page configuration
|
10 |
+
st.set_page_config(
|
11 |
+
page_title="NBA Player Performance Predictor",
|
12 |
+
page_icon="🏀",
|
13 |
+
layout="centered"
|
14 |
+
)
|
15 |
|
16 |
# Mapping for position to numeric values
|
17 |
position_mapping = {
|
|
|
22 |
"C": 5.0,
|
23 |
}
|
24 |
|
25 |
+
# Injury types and average days
|
26 |
injury_types = [
|
27 |
"foot fracture injury", "hip flexor surgery injury", "calf strain injury",
|
28 |
"quad injury injury", "shoulder sprain injury", "foot sprain injury",
|
|
|
36 |
|
37 |
average_days_injured = {
|
38 |
"foot fracture injury": 207.666667,
|
39 |
+
"hip flexor surgery injury": 256.000000,
|
40 |
+
"calf strain injury": 236.000000,
|
41 |
+
"quad injury injury": 283.000000,
|
42 |
+
"shoulder sprain injury": 259.500000,
|
43 |
+
"foot sprain injury": 294.000000,
|
44 |
+
"torn rotator cuff injury injury": 251.500000,
|
45 |
+
"torn mcl injury": 271.000000,
|
46 |
+
"hip flexor strain injury": 253.000000,
|
47 |
+
"fractured leg injury": 250.250000,
|
48 |
+
"sprained mcl injury": 228.666667,
|
49 |
+
"ankle sprain injury": 231.333333,
|
50 |
+
"hamstring injury injury": 220.000000,
|
51 |
+
"meniscus tear injury": 201.250000,
|
52 |
+
"torn hamstring injury": 187.666667,
|
53 |
+
"dislocated shoulder injury": 269.000000,
|
54 |
+
"ankle fracture injury": 114.500000,
|
55 |
+
"fractured hand injury": 169.142857,
|
56 |
+
"bone spurs injury": 151.500000,
|
57 |
+
"acl tear injury": 268.000000,
|
58 |
+
"hip labrum injury": 247.500000,
|
59 |
+
"back surgery injury": 215.800000,
|
60 |
+
"arm injury injury": 303.666667,
|
61 |
+
"torn shoulder labrum injury": 195.666667,
|
62 |
"lower back spasm injury": 234.000000,
|
63 |
}
|
64 |
|
65 |
+
team_logo_paths = {
|
66 |
+
"Cleveland Cavaliers": "NBA_LOGOs/Clevelan-Cavaliers-logo-2022.png",
|
67 |
+
"Atlanta Hawks": "NBA_LOGOs/nba-atlanta-hawks-logo.png",
|
68 |
+
"Boston Celtics": "NBA_LOGOs/nba-boston-celtics-logo.png",
|
69 |
+
"Brooklyn Nets": "NBA_LOGOs/nba-brooklyn-nets-logo.png",
|
70 |
+
"Charlotte Hornets": "NBA_LOGOs/nba-charlotte-hornets-logo.png",
|
71 |
+
"Chicago Bulls": "NBA_LOGOs/nba-chicago-bulls-logo.png",
|
72 |
+
"Dallas Mavericks": "NBA_LOGOs/nba-dallas-mavericks-logo.png",
|
73 |
+
"Denver Nuggets": "NBA_LOGOs/nba-denver-nuggets-logo-2018.png",
|
74 |
+
"Detroit Pistons": "NBA_LOGOs/nba-detroit-pistons-logo.png",
|
75 |
+
"Golden State Warriors": "NBA_LOGOs/nba-golden-state-warriors-logo-2020.png",
|
76 |
+
"Houston Rockets": "NBA_LOGOs/nba-houston-rockets-logo-2020.png",
|
77 |
+
"Indiana Pacers": "NBA_LOGOs/nba-indiana-pacers-logo.png",
|
78 |
+
"LA Clippers": "NBA_LOGOs/nba-la-clippers-logo.png",
|
79 |
+
"Los Angeles Lakers": "NBA_LOGOs/nba-los-angeles-lakers-logo.png",
|
80 |
+
"Memphis Grizzlies": "NBA_LOGOs/nba-memphis-grizzlies-logo.png",
|
81 |
+
"Miami Heat": "NBA_LOGOs/nba-miami-heat-logo.png",
|
82 |
+
"Milwaukee Bucks": "NBA_LOGOs/nba-milwaukee-bucks-logo.png",
|
83 |
+
"Minnesota Timberwolves": "NBA_LOGOs/nba-minnesota-timberwolves-logo.png",
|
84 |
+
"New Orleans Pelicans": "NBA_LOGOs/nba-new-orleans-pelicans-logo.png",
|
85 |
+
"New York Knicks": "NBA_LOGOs/nba-new-york-knicks-logo.png",
|
86 |
+
"Oklahoma City Thunder": "NBA_LOGOs/nba-oklahoma-city-thunder-logo.png",
|
87 |
+
"Orlando Magic": "NBA_LOGOs/nba-orlando-magic-logo.png",
|
88 |
+
"Philadelphia 76ers": "NBA_LOGOs/nba-philadelphia-76ers-logo.png",
|
89 |
+
"Phoenix Suns": "NBA_LOGOs/nba-phoenix-suns-logo.png",
|
90 |
+
"Portland Trail Blazers": "NBA_LOGOs/nba-portland-trail-blazers-logo.png",
|
91 |
+
"Sacramento Kings": "NBA_LOGOs/nba-sacramento-kings-logo.png",
|
92 |
+
"San Antonio Spurs": "NBA_LOGOs/nba-san-antonio-spurs-logo.png",
|
93 |
+
"Toronto Raptors": "NBA_LOGOs/nba-toronto-raptors-logo-2020.png",
|
94 |
+
"Utah Jazz": "NBA_LOGOs/nba-utah-jazz-logo.png",
|
95 |
+
"Washington Wizards": "NBA_LOGOs/nba-washington-wizards-logo.png",
|
96 |
+
}
|
97 |
+
|
98 |
+
# Caching player data and model
|
99 |
@st.cache_resource
|
100 |
def load_player_data():
|
101 |
return pd.read_csv("player_data.csv")
|
|
|
104 |
def load_rf_model():
|
105 |
return joblib.load("rf_injury_change_model.pkl")
|
106 |
|
107 |
+
# Main application
|
108 |
+
# Main application
|
109 |
def main():
|
110 |
st.title("NBA Player Performance Predictor 🏀")
|
111 |
+
st.markdown("""
|
112 |
Use this tool to predict how a player's performance metrics might change
|
113 |
if they experience a hypothetical injury.
|
114 |
""")
|
115 |
|
116 |
+
# Load data and model
|
117 |
player_data = load_player_data()
|
118 |
rf_model = load_rf_model()
|
119 |
|
120 |
+
# Sidebar inputs
|
121 |
with st.sidebar:
|
122 |
st.header("Player & Injury Inputs")
|
123 |
player_list = sorted(player_data['player_name'].dropna().unique())
|
124 |
player_name = st.selectbox("Select Player", player_list)
|
125 |
|
|
|
126 |
if player_name:
|
127 |
+
# Filter data for the selected player
|
128 |
player_row = player_data[player_data['player_name'] == player_name]
|
129 |
+
team_name = player_row.iloc[0]['team_abbreviation']
|
130 |
position = player_row.iloc[0]['position']
|
|
|
|
|
|
|
131 |
stats_columns = ['age', 'player_height', 'player_weight']
|
|
|
132 |
|
133 |
+
st.write(f"**Position**: {position}")
|
134 |
+
st.write(f"**Team**: {team_name}")
|
135 |
+
|
136 |
+
# Player stats
|
137 |
+
default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
|
138 |
for stat in default_stats.keys():
|
139 |
default_stats[stat] = st.number_input(f"{stat}", value=default_stats[stat])
|
140 |
|
141 |
+
# Injury details
|
142 |
injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
|
143 |
default_days_injured = average_days_injured.get(injury_type, 30)
|
144 |
days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
|
145 |
injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)
|
146 |
|
147 |
+
# Prepare data for prediction
|
148 |
input_data = pd.DataFrame([{
|
149 |
"days_injured": days_injured,
|
150 |
"injury_occurrences": injury_occurrences,
|
151 |
+
"position": position_mapping.get(position, 0),
|
152 |
"injury_type": injury_type,
|
153 |
**default_stats
|
154 |
}])
|
155 |
input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
|
156 |
|
157 |
+
st.header("Prediction Results")
|
158 |
+
if st.button("Predict"):
|
159 |
+
predictions = rf_model.predict(input_data)
|
160 |
+
predictions = [round(float(pred), 2) for pred in predictions]
|
161 |
+
|
162 |
+
# Display prediction results
|
163 |
+
prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
|
164 |
+
result_df = pd.DataFrame([predictions], columns=prediction_columns)
|
165 |
+
st.table(result_df)
|
166 |
+
|
167 |
+
# Main content layout
|
168 |
+
st.divider()
|
169 |
+
st.header("Player Overview")
|
170 |
+
col1, col2 = st.columns([1, 2])
|
171 |
|
172 |
with col1:
|
173 |
st.subheader("Player Details")
|
174 |
st.metric("Age", default_stats['age'])
|
175 |
st.metric("Height (cm)", default_stats['player_height'])
|
176 |
st.metric("Weight (kg)", default_stats['player_weight'])
|
|
|
177 |
|
178 |
with col2:
|
179 |
+
# Display team logo
|
180 |
+
if team_name in team_logo_paths:
|
181 |
+
logo_path = team_logo_paths[team_name]
|
182 |
+
try:
|
183 |
+
logo_image = Image.open(logo_path)
|
184 |
+
st.image(logo_image, caption=f"{team_name} Logo", use_column_width=True)
|
185 |
+
except FileNotFoundError:
|
186 |
+
st.error(f"Logo for {team_name} not found.")
|
187 |
+
|
188 |
+
|
189 |
+
# Graphs for PPG, AST, and REB
|
190 |
+
st.divider()
|
191 |
+
st.header("Player Performance Graphs")
|
192 |
+
|
193 |
+
if st.button("Show Performance Graphs"):
|
194 |
+
# Filter data for the selected player
|
195 |
+
player_data_filtered = player_data[player_data["player_name"] == player_name].sort_values(by="season")
|
196 |
+
|
197 |
+
# Ensure all seasons are included
|
198 |
+
all_seasons = pd.Series(range(player_data["season"].min(), player_data["season"].max() + 1))
|
199 |
+
player_data_filtered = (
|
200 |
+
pd.DataFrame({"season": all_seasons})
|
201 |
+
.merge(player_data_filtered, on="season", how="left")
|
202 |
+
.fillna({"pts": 0, "ast": 0, "reb": 0}) # Fill missing values
|
203 |
+
)
|
204 |
+
|
205 |
+
if not player_data_filtered.empty:
|
206 |
+
# PPG Graph
|
207 |
+
fig_ppg = px.line(
|
208 |
+
player_data_filtered,
|
209 |
+
x="season",
|
210 |
+
y="pts",
|
211 |
+
title=f"{player_name}: Points Per Game (PPG) Over Seasons",
|
212 |
+
labels={"pts": "Points Per Game (PPG)", "season": "Season"},
|
213 |
+
markers=True
|
214 |
+
)
|
215 |
+
fig_ppg.update_layout(template="plotly_white")
|
216 |
+
|
217 |
+
# AST Graph
|
218 |
+
fig_ast = px.line(
|
219 |
+
player_data_filtered,
|
220 |
+
x="season",
|
221 |
+
y="ast",
|
222 |
+
title=f"{player_name}: Assists Per Game (AST) Over Seasons",
|
223 |
+
labels={"ast": "Assists Per Game (AST)", "season": "Season"},
|
224 |
+
markers=True
|
225 |
+
)
|
226 |
+
fig_ast.update_layout(template="plotly_white")
|
227 |
+
|
228 |
+
# REB Graph
|
229 |
+
fig_reb = px.line(
|
230 |
+
player_data_filtered,
|
231 |
+
x="season",
|
232 |
+
y="reb",
|
233 |
+
title=f"{player_name}: Rebounds Per Game (REB) Over Seasons",
|
234 |
+
labels={"reb": "Rebounds Per Game (REB)", "season": "Season"},
|
235 |
+
markers=True
|
236 |
+
)
|
237 |
+
fig_reb.update_layout(template="plotly_white")
|
238 |
+
|
239 |
+
# Display graphs
|
240 |
+
st.plotly_chart(fig_ppg, use_container_width=True)
|
241 |
+
st.plotly_chart(fig_ast, use_container_width=True)
|
242 |
+
st.plotly_chart(fig_reb, use_container_width=True)
|
243 |
+
else:
|
244 |
+
st.error("No data available for the selected player.")
|
245 |
+
|
246 |
+
# Footer
|
247 |
+
st.divider()
|
248 |
+
st.markdown("""
|
249 |
+
### About This Tool
|
250 |
+
This application predicts how injuries might impact an NBA player's performance using machine learning models. Data is based on historical player stats and injuries.
|
251 |
+
""")
|
252 |
|
253 |
if __name__ == "__main__":
|
254 |
main()
|