Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,12 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import joblib
|
4 |
-
from sklearn.ensemble import RandomForestRegressor
|
5 |
-
import plotly.graph_objects as go
|
6 |
-
from PIL import Image
|
7 |
import plotly.express as px
|
|
|
8 |
|
9 |
# Set the page configuration
|
10 |
st.set_page_config(
|
11 |
-
page_title="NBA Player Performance Predictor",
|
12 |
page_icon="π",
|
13 |
layout="centered"
|
14 |
)
|
@@ -63,36 +61,36 @@ average_days_injured = {
|
|
63 |
}
|
64 |
|
65 |
team_logo_paths = {
|
66 |
-
"Cleveland Cavaliers": "Clevelan-Cavaliers-logo-2022.png",
|
67 |
-
"Atlanta Hawks": "nba-atlanta-hawks-logo.png",
|
68 |
-
"Boston Celtics": "nba-boston-celtics-logo.png",
|
69 |
-
"Brooklyn Nets": "nba-brooklyn-nets-logo.png",
|
70 |
-
"Charlotte Hornets": "nba-charlotte-hornets-logo.png",
|
71 |
-
"Chicago Bulls": "nba-chicago-bulls-logo.png",
|
72 |
-
"Dallas Mavericks": "nba-dallas-mavericks-logo.png",
|
73 |
-
"Denver Nuggets": "nba-denver-nuggets-logo-2018.png",
|
74 |
-
"Detroit Pistons": "nba-detroit-pistons-logo.png",
|
75 |
-
"Golden State Warriors": "nba-golden-state-warriors-logo-2020.png",
|
76 |
-
"Houston Rockets": "nba-houston-rockets-logo-2020.png",
|
77 |
-
"Indiana Pacers": "nba-indiana-pacers-logo.png",
|
78 |
-
"LA Clippers": "nba-la-clippers-logo.png",
|
79 |
-
"Los Angeles Lakers": "nba-los-angeles-lakers-logo.png",
|
80 |
-
"Memphis Grizzlies": "nba-memphis-grizzlies-logo.png",
|
81 |
-
"Miami Heat": "nba-miami-heat-logo.png",
|
82 |
-
"Milwaukee Bucks": "nba-milwaukee-bucks-logo.png",
|
83 |
-
"Minnesota Timberwolves": "nba-minnesota-timberwolves-logo.png",
|
84 |
-
"New Orleans Pelicans": "nba-new-orleans-pelicans-logo.png",
|
85 |
-
"New York Knicks": "nba-new-york-knicks-logo.png",
|
86 |
-
"Oklahoma City Thunder": "nba-oklahoma-city-thunder-logo.png",
|
87 |
-
"Orlando Magic": "nba-orlando-magic-logo.png",
|
88 |
-
"Philadelphia 76ers": "nba-philadelphia-76ers-logo.png",
|
89 |
-
"Phoenix Suns": "nba-phoenix-suns-logo.png",
|
90 |
-
"Portland Trail Blazers": "nba-portland-trail-blazers-logo.png",
|
91 |
-
"Sacramento Kings": "nba-sacramento-kings-logo.png",
|
92 |
-
"San Antonio Spurs": "nba-san-antonio-spurs-logo.png",
|
93 |
-
"Toronto Raptors": "nba-toronto-raptors-logo-2020.png",
|
94 |
-
"Utah Jazz": "nba-utah-jazz-logo.png",
|
95 |
-
"Washington Wizards": "nba-washington-wizards-logo.png",
|
96 |
}
|
97 |
|
98 |
# Caching player data and model
|
@@ -104,14 +102,15 @@ def load_player_data():
|
|
104 |
def load_rf_model():
|
105 |
return joblib.load("rf_injury_change_model.pkl")
|
106 |
|
107 |
-
# Main application
|
108 |
# Main application
|
109 |
def main():
|
110 |
st.title("NBA Player Performance Predictor π")
|
111 |
-
st.markdown(
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
115 |
|
116 |
# Load data and model
|
117 |
player_data = load_player_data()
|
@@ -152,103 +151,32 @@ def main():
|
|
152 |
"injury_type": injury_type,
|
153 |
**default_stats
|
154 |
}])
|
155 |
-
input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
|
156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
st.header("Prediction Results")
|
158 |
-
if st.button("Predict"):
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
st.header("Player Overview")
|
170 |
-
col1, col2 = st.columns([1, 2])
|
171 |
-
|
172 |
-
with col1:
|
173 |
-
st.subheader("Player Details")
|
174 |
-
st.metric("Age", default_stats['age'])
|
175 |
-
st.metric("Height (cm)", default_stats['player_height'])
|
176 |
-
st.metric("Weight (kg)", default_stats['player_weight'])
|
177 |
-
|
178 |
-
with col2:
|
179 |
-
# Display team logo
|
180 |
-
if team_name in team_logo_paths:
|
181 |
-
logo_path = team_logo_paths[team_name]
|
182 |
-
try:
|
183 |
-
logo_image = Image.open(logo_path)
|
184 |
-
st.image(logo_image, caption=f"{team_name} Logo", use_container_width=True)
|
185 |
-
except FileNotFoundError:
|
186 |
-
st.error(f"Logo for {team_name} not found.")
|
187 |
-
|
188 |
-
|
189 |
-
# Graphs for PPG, AST, and REB
|
190 |
-
st.divider()
|
191 |
-
st.header("Player Performance Graphs")
|
192 |
-
|
193 |
-
if st.button("Show Performance Graphs"):
|
194 |
-
# Filter data for the selected player
|
195 |
-
player_data_filtered = player_data[player_data["player_name"] == player_name].sort_values(by="season")
|
196 |
-
|
197 |
-
# Ensure all seasons are included
|
198 |
-
all_seasons = pd.Series(range(player_data["season"].min(), player_data["season"].max() + 1))
|
199 |
-
player_data_filtered = (
|
200 |
-
pd.DataFrame({"season": all_seasons})
|
201 |
-
.merge(player_data_filtered, on="season", how="left")
|
202 |
-
.fillna({"pts": 0, "ast": 0, "reb": 0}) # Fill missing values
|
203 |
-
)
|
204 |
-
|
205 |
-
if not player_data_filtered.empty:
|
206 |
-
# PPG Graph
|
207 |
-
fig_ppg = px.line(
|
208 |
-
player_data_filtered,
|
209 |
-
x="season",
|
210 |
-
y="pts",
|
211 |
-
title=f"{player_name}: Points Per Game (PPG) Over Seasons",
|
212 |
-
labels={"pts": "Points Per Game (PPG)", "season": "Season"},
|
213 |
-
markers=True
|
214 |
-
)
|
215 |
-
fig_ppg.update_layout(template="plotly_white")
|
216 |
-
|
217 |
-
# AST Graph
|
218 |
-
fig_ast = px.line(
|
219 |
-
player_data_filtered,
|
220 |
-
x="season",
|
221 |
-
y="ast",
|
222 |
-
title=f"{player_name}: Assists Per Game (AST) Over Seasons",
|
223 |
-
labels={"ast": "Assists Per Game (AST)", "season": "Season"},
|
224 |
-
markers=True
|
225 |
-
)
|
226 |
-
fig_ast.update_layout(template="plotly_white")
|
227 |
-
|
228 |
-
# REB Graph
|
229 |
-
fig_reb = px.line(
|
230 |
-
player_data_filtered,
|
231 |
-
x="season",
|
232 |
-
y="reb",
|
233 |
-
title=f"{player_name}: Rebounds Per Game (REB) Over Seasons",
|
234 |
-
labels={"reb": "Rebounds Per Game (REB)", "season": "Season"},
|
235 |
-
markers=True
|
236 |
-
)
|
237 |
-
fig_reb.update_layout(template="plotly_white")
|
238 |
-
|
239 |
-
# Display graphs
|
240 |
-
st.plotly_chart(fig_ppg, use_container_width=True)
|
241 |
-
st.plotly_chart(fig_ast, use_container_width=True)
|
242 |
-
st.plotly_chart(fig_reb, use_container_width=True)
|
243 |
-
else:
|
244 |
-
st.error("No data available for the selected player.")
|
245 |
-
|
246 |
-
# Footer
|
247 |
-
st.divider()
|
248 |
-
st.markdown("""
|
249 |
-
### About This Tool
|
250 |
-
This application predicts how injuries might impact an NBA player's performance using machine learning models. Data is based on historical player stats and injuries.
|
251 |
-
""")
|
252 |
|
253 |
if __name__ == "__main__":
|
254 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import joblib
|
|
|
|
|
|
|
4 |
import plotly.express as px
|
5 |
+
from PIL import Image
|
6 |
|
7 |
# Set the page configuration
|
8 |
st.set_page_config(
|
9 |
+
page_title="NBA Player Performance Predictor π",
|
10 |
page_icon="π",
|
11 |
layout="centered"
|
12 |
)
|
|
|
61 |
}
|
62 |
|
63 |
team_logo_paths = {
|
64 |
+
"Cleveland Cavaliers": "NBA_LOGOs/Clevelan-Cavaliers-logo-2022.png",
|
65 |
+
"Atlanta Hawks": "NBA_LOGOs/nba-atlanta-hawks-logo.png",
|
66 |
+
"Boston Celtics": "NBA_LOGOs/nba-boston-celtics-logo.png",
|
67 |
+
"Brooklyn Nets": "NBA_LOGOs/nba-brooklyn-nets-logo.png",
|
68 |
+
"Charlotte Hornets": "NBA_LOGOs/nba-charlotte-hornets-logo.png",
|
69 |
+
"Chicago Bulls": "NBA_LOGOs/nba-chicago-bulls-logo.png",
|
70 |
+
"Dallas Mavericks": "NBA_LOGOs/nba-dallas-mavericks-logo.png",
|
71 |
+
"Denver Nuggets": "NBA_LOGOs/nba-denver-nuggets-logo-2018.png",
|
72 |
+
"Detroit Pistons": "NBA_LOGOs/nba-detroit-pistons-logo.png",
|
73 |
+
"Golden State Warriors": "NBA_LOGOs/nba-golden-state-warriors-logo-2020.png",
|
74 |
+
"Houston Rockets": "NBA_LOGOs/nba-houston-rockets-logo-2020.png",
|
75 |
+
"Indiana Pacers": "NBA_LOGOs/nba-indiana-pacers-logo.png",
|
76 |
+
"LA Clippers": "NBA_LOGOs/nba-la-clippers-logo.png",
|
77 |
+
"Los Angeles Lakers": "NBA_LOGOs/nba-los-angeles-lakers-logo.png",
|
78 |
+
"Memphis Grizzlies": "NBA_LOGOs/nba-memphis-grizzlies-logo.png",
|
79 |
+
"Miami Heat": "NBA_LOGOs/nba-miami-heat-logo.png",
|
80 |
+
"Milwaukee Bucks": "NBA_LOGOs/nba-milwaukee-bucks-logo.png",
|
81 |
+
"Minnesota Timberwolves": "NBA_LOGOs/nba-minnesota-timberwolves-logo.png",
|
82 |
+
"New Orleans Pelicans": "NBA_LOGOs/nba-new-orleans-pelicans-logo.png",
|
83 |
+
"New York Knicks": "NBA_LOGOs/nba-new-york-knicks-logo.png",
|
84 |
+
"Oklahoma City Thunder": "NBA_LOGOs/nba-oklahoma-city-thunder-logo.png",
|
85 |
+
"Orlando Magic": "NBA_LOGOs/nba-orlando-magic-logo.png",
|
86 |
+
"Philadelphia 76ers": "NBA_LOGOs/nba-philadelphia-76ers-logo.png",
|
87 |
+
"Phoenix Suns": "NBA_LOGOs/nba-phoenix-suns-logo.png",
|
88 |
+
"Portland Trail Blazers": "NBA_LOGOs/nba-portland-trail-blazers-logo.png",
|
89 |
+
"Sacramento Kings": "NBA_LOGOs/nba-sacramento-kings-logo.png",
|
90 |
+
"San Antonio Spurs": "NBA_LOGOs/nba-san-antonio-spurs-logo.png",
|
91 |
+
"Toronto Raptors": "NBA_LOGOs/nba-toronto-raptors-logo-2020.png",
|
92 |
+
"Utah Jazz": "NBA_LOGOs/nba-utah-jazz-logo.png",
|
93 |
+
"Washington Wizards": "NBA_LOGOs/nba-washington-wizards-logo.png",
|
94 |
}
|
95 |
|
96 |
# Caching player data and model
|
|
|
102 |
def load_rf_model():
|
103 |
return joblib.load("rf_injury_change_model.pkl")
|
104 |
|
|
|
105 |
# Main application
|
106 |
def main():
|
107 |
st.title("NBA Player Performance Predictor π")
|
108 |
+
st.markdown(
|
109 |
+
"""
|
110 |
+
Use this tool to predict how a player's performance metrics might change
|
111 |
+
if they experience a hypothetical injury.
|
112 |
+
"""
|
113 |
+
)
|
114 |
|
115 |
# Load data and model
|
116 |
player_data = load_player_data()
|
|
|
151 |
"injury_type": injury_type,
|
152 |
**default_stats
|
153 |
}])
|
|
|
154 |
|
155 |
+
# One-hot encode injury type
|
156 |
+
input_data = pd.get_dummies(input_data, columns=["injury_type"], drop_first=True)
|
157 |
+
|
158 |
+
# Align with model's feature names
|
159 |
+
expected_features = rf_model.feature_names_in_
|
160 |
+
for feature in expected_features:
|
161 |
+
if feature not in input_data.columns:
|
162 |
+
input_data[feature] = 0
|
163 |
+
|
164 |
+
# Ensure columns are in the same order as the model's feature names
|
165 |
+
input_data = input_data[expected_features]
|
166 |
+
|
167 |
+
# Predict and display results
|
168 |
st.header("Prediction Results")
|
169 |
+
if st.sidebar.button("Predict"):
|
170 |
+
try:
|
171 |
+
predictions = rf_model.predict(input_data)
|
172 |
+
prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
|
173 |
+
st.subheader("Predicted Post-Injury Performance")
|
174 |
+
st.write("Based on the inputs, here are the predicted metrics:")
|
175 |
+
st.table(pd.DataFrame(predictions, columns=prediction_columns))
|
176 |
+
except FileNotFoundError:
|
177 |
+
st.error("Model file not found.")
|
178 |
+
except ValueError as e:
|
179 |
+
st.error(f"Error during prediction: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
if __name__ == "__main__":
|
182 |
main()
|