Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,52 @@ import pandas as pd
|
|
3 |
import joblib
|
4 |
from sklearn.ensemble import RandomForestRegressor
|
5 |
import plotly.express as px
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Mapping for position to numeric values
|
8 |
position_mapping = {
|
@@ -72,18 +118,15 @@ average_days_injured = {
|
|
72 |
}
|
73 |
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
# Load player dataset
|
79 |
@st.cache_resource
|
80 |
def load_player_data():
|
81 |
-
return pd.read_csv("
|
82 |
|
83 |
# Load Random Forest model
|
84 |
@st.cache_resource
|
85 |
def load_rf_model():
|
86 |
-
return joblib.load("
|
87 |
|
88 |
# Main Streamlit app
|
89 |
def main():
|
@@ -109,6 +152,8 @@ def main():
|
|
109 |
if player_name:
|
110 |
# Retrieve player details
|
111 |
player_row = player_data[player_data['player_name'] == player_name]
|
|
|
|
|
112 |
|
113 |
if not player_row.empty:
|
114 |
position = player_row.iloc[0]['position']
|
@@ -177,5 +222,90 @@ def main():
|
|
177 |
else:
|
178 |
st.sidebar.error("Please select a player to view details.")
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
if __name__ == "__main__":
|
181 |
main()
|
|
|
3 |
import joblib
|
4 |
from sklearn.ensemble import RandomForestRegressor
|
5 |
import plotly.express as px
|
6 |
+
from sklearn.ensemble import RandomForestRegressor
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
from PIL import Image
|
9 |
+
import plotly.express as px
|
10 |
+
|
11 |
+
# Set the page configuration
|
12 |
+
st.set_page_config(
|
13 |
+
page_title="NBA Player Performance Predictor",
|
14 |
+
page_icon="🏀",
|
15 |
+
layout="centered"
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
team_logo_paths = {
|
20 |
+
"Cleveland Cavaliers": "Clevelan-Cavaliers-logo-2022.png",
|
21 |
+
"Atlanta Hawks": "nba-atlanta-hawks-logo.png",
|
22 |
+
"Boston Celtics": "nba-boston-celtics-logo.png",
|
23 |
+
"Brooklyn Nets": "nba-brooklyn-nets-logo.png",
|
24 |
+
"Charlotte Hornets": "nba-charlotte-hornets-logo.png",
|
25 |
+
"Chicago Bulls": "nba-chicago-bulls-logo.png",
|
26 |
+
"Dallas Mavericks": "nba-dallas-mavericks-logo.png",
|
27 |
+
"Denver Nuggets": "nba-denver-nuggets-logo-2018.png",
|
28 |
+
"Detroit Pistons": "nba-detroit-pistons-logo.png",
|
29 |
+
"Golden State Warriors": "nba-golden-state-warriors-logo-2020.png",
|
30 |
+
"Houston Rockets": "nba-houston-rockets-logo-2020.png",
|
31 |
+
"Indiana Pacers": "nba-indiana-pacers-logo.png",
|
32 |
+
"LA Clippers": "nba-la-clippers-logo.png",
|
33 |
+
"Los Angeles Lakers": "nba-los-angeles-lakers-logo.png",
|
34 |
+
"Memphis Grizzlies": "nba-memphis-grizzlies-logo.png",
|
35 |
+
"Miami Heat": "nba-miami-heat-logo.png",
|
36 |
+
"Milwaukee Bucks": "nba-milwaukee-bucks-logo.png",
|
37 |
+
"Minnesota Timberwolves": "nba-minnesota-timberwolves-logo.png",
|
38 |
+
"New Orleans Pelicans": "nba-new-orleans-pelicans-logo.png",
|
39 |
+
"New York Knicks": "nba-new-york-knicks-logo.png",
|
40 |
+
"Oklahoma City Thunder": "nba-oklahoma-city-thunder-logo.png",
|
41 |
+
"Orlando Magic": "nba-orlando-magic-logo.png",
|
42 |
+
"Philadelphia 76ers": "nba-philadelphia-76ers-logo.png",
|
43 |
+
"Phoenix Suns": "nba-phoenix-suns-logo.png",
|
44 |
+
"Portland Trail Blazers": "nba-portland-trail-blazers-logo.png",
|
45 |
+
"Sacramento Kings": "nba-sacramento-kings-logo.png",
|
46 |
+
"San Antonio Spurs": "nba-san-antonio-spurs-logo.png",
|
47 |
+
"Toronto Raptors": "nba-toronto-raptors-logo-2020.png",
|
48 |
+
"Utah Jazz": "nba-utah-jazz-logo.png",
|
49 |
+
"Washington Wizards": "nba-washington-wizards-logo.png",
|
50 |
+
}
|
51 |
+
|
52 |
|
53 |
# Mapping for position to numeric values
|
54 |
position_mapping = {
|
|
|
118 |
}
|
119 |
|
120 |
|
|
|
|
|
|
|
121 |
# Load player dataset
|
122 |
@st.cache_resource
|
123 |
def load_player_data():
|
124 |
+
return pd.read_csv("player_data.csv")
|
125 |
|
126 |
# Load Random Forest model
|
127 |
@st.cache_resource
|
128 |
def load_rf_model():
|
129 |
+
return joblib.load("rf_injury_change_model.pkl")
|
130 |
|
131 |
# Main Streamlit app
|
132 |
def main():
|
|
|
152 |
if player_name:
|
153 |
# Retrieve player details
|
154 |
player_row = player_data[player_data['player_name'] == player_name]
|
155 |
+
team_name = player_row.iloc[0]['team_abbreviation']
|
156 |
+
position = player_row.iloc[0]['position']
|
157 |
|
158 |
if not player_row.empty:
|
159 |
position = player_row.iloc[0]['position']
|
|
|
222 |
else:
|
223 |
st.sidebar.error("Please select a player to view details.")
|
224 |
|
225 |
+
st.divider()
|
226 |
+
st.header("Player Overview")
|
227 |
+
col1, col2 = st.columns([1, 2])
|
228 |
+
|
229 |
+
with col1:
|
230 |
+
st.subheader("Player Details")
|
231 |
+
st.metric("Age", default_stats['age'])
|
232 |
+
st.metric("Height (cm)", default_stats['player_height'])
|
233 |
+
st.metric("Weight (kg)", default_stats['player_weight'])
|
234 |
+
|
235 |
+
with col2:
|
236 |
+
# Display team logo
|
237 |
+
if team_name in team_logo_paths:
|
238 |
+
logo_path = team_logo_paths[team_name]
|
239 |
+
try:
|
240 |
+
logo_image = Image.open(logo_path)
|
241 |
+
st.image(logo_image, caption=f"{team_name} Logo", use_container_width=True)
|
242 |
+
except FileNotFoundError:
|
243 |
+
st.error(f"Logo for {team_name} not found.")
|
244 |
+
|
245 |
+
|
246 |
+
# Graphs for PPG, AST, and REB
|
247 |
+
st.divider()
|
248 |
+
st.header("Player Performance Graphs")
|
249 |
+
|
250 |
+
if st.button("Show Performance Graphs"):
|
251 |
+
# Filter data for the selected player
|
252 |
+
player_data_filtered = player_data[player_data["player_name"] == player_name].sort_values(by="season")
|
253 |
+
|
254 |
+
# Ensure all seasons are included
|
255 |
+
all_seasons = pd.Series(range(player_data["season"].min(), player_data["season"].max() + 1))
|
256 |
+
player_data_filtered = (
|
257 |
+
pd.DataFrame({"season": all_seasons})
|
258 |
+
.merge(player_data_filtered, on="season", how="left")
|
259 |
+
.fillna({"pts": 0, "ast": 0, "reb": 0}) # Fill missing values
|
260 |
+
)
|
261 |
+
|
262 |
+
if not player_data_filtered.empty:
|
263 |
+
# PPG Graph
|
264 |
+
fig_ppg = px.line(
|
265 |
+
player_data_filtered,
|
266 |
+
x="season",
|
267 |
+
y="pts",
|
268 |
+
title=f"{player_name}: Points Per Game (PPG) Over Seasons",
|
269 |
+
labels={"pts": "Points Per Game (PPG)", "season": "Season"},
|
270 |
+
markers=True
|
271 |
+
)
|
272 |
+
fig_ppg.update_layout(template="plotly_white")
|
273 |
+
|
274 |
+
# AST Graph
|
275 |
+
fig_ast = px.line(
|
276 |
+
player_data_filtered,
|
277 |
+
x="season",
|
278 |
+
y="ast",
|
279 |
+
title=f"{player_name}: Assists Per Game (AST) Over Seasons",
|
280 |
+
labels={"ast": "Assists Per Game (AST)", "season": "Season"},
|
281 |
+
markers=True
|
282 |
+
)
|
283 |
+
fig_ast.update_layout(template="plotly_white")
|
284 |
+
|
285 |
+
# REB Graph
|
286 |
+
fig_reb = px.line(
|
287 |
+
player_data_filtered,
|
288 |
+
x="season",
|
289 |
+
y="reb",
|
290 |
+
title=f"{player_name}: Rebounds Per Game (REB) Over Seasons",
|
291 |
+
labels={"reb": "Rebounds Per Game (REB)", "season": "Season"},
|
292 |
+
markers=True
|
293 |
+
)
|
294 |
+
fig_reb.update_layout(template="plotly_white")
|
295 |
+
|
296 |
+
# Display graphs
|
297 |
+
st.plotly_chart(fig_ppg, use_container_width=True)
|
298 |
+
st.plotly_chart(fig_ast, use_container_width=True)
|
299 |
+
st.plotly_chart(fig_reb, use_container_width=True)
|
300 |
+
else:
|
301 |
+
st.error("No data available for the selected player.")
|
302 |
+
|
303 |
+
# Footer
|
304 |
+
st.divider()
|
305 |
+
st.markdown("""
|
306 |
+
### About This Tool
|
307 |
+
This application predicts how injuries might impact an NBA player's performance using machine learning models. Data is based on historical player stats and injuries.
|
308 |
+
""")
|
309 |
+
|
310 |
if __name__ == "__main__":
|
311 |
main()
|