lrschuman17 commited on
Commit
a2f85fb
Β·
verified Β·
1 Parent(s): 6258529

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -134
app.py CHANGED
@@ -1,14 +1,12 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import joblib
4
- from sklearn.ensemble import RandomForestRegressor
5
- import plotly.graph_objects as go
6
- from PIL import Image
7
  import plotly.express as px
 
8
 
9
  # Set the page configuration
10
  st.set_page_config(
11
- page_title="NBA Player Performance Predictor",
12
  page_icon="πŸ€",
13
  layout="centered"
14
  )
@@ -63,36 +61,36 @@ average_days_injured = {
63
  }
64
 
65
  team_logo_paths = {
66
- "Cleveland Cavaliers": "Clevelan-Cavaliers-logo-2022.png",
67
- "Atlanta Hawks": "nba-atlanta-hawks-logo.png",
68
- "Boston Celtics": "nba-boston-celtics-logo.png",
69
- "Brooklyn Nets": "nba-brooklyn-nets-logo.png",
70
- "Charlotte Hornets": "nba-charlotte-hornets-logo.png",
71
- "Chicago Bulls": "nba-chicago-bulls-logo.png",
72
- "Dallas Mavericks": "nba-dallas-mavericks-logo.png",
73
- "Denver Nuggets": "nba-denver-nuggets-logo-2018.png",
74
- "Detroit Pistons": "nba-detroit-pistons-logo.png",
75
- "Golden State Warriors": "nba-golden-state-warriors-logo-2020.png",
76
- "Houston Rockets": "nba-houston-rockets-logo-2020.png",
77
- "Indiana Pacers": "nba-indiana-pacers-logo.png",
78
- "LA Clippers": "nba-la-clippers-logo.png",
79
- "Los Angeles Lakers": "nba-los-angeles-lakers-logo.png",
80
- "Memphis Grizzlies": "nba-memphis-grizzlies-logo.png",
81
- "Miami Heat": "nba-miami-heat-logo.png",
82
- "Milwaukee Bucks": "nba-milwaukee-bucks-logo.png",
83
- "Minnesota Timberwolves": "nba-minnesota-timberwolves-logo.png",
84
- "New Orleans Pelicans": "nba-new-orleans-pelicans-logo.png",
85
- "New York Knicks": "nba-new-york-knicks-logo.png",
86
- "Oklahoma City Thunder": "nba-oklahoma-city-thunder-logo.png",
87
- "Orlando Magic": "nba-orlando-magic-logo.png",
88
- "Philadelphia 76ers": "nba-philadelphia-76ers-logo.png",
89
- "Phoenix Suns": "nba-phoenix-suns-logo.png",
90
- "Portland Trail Blazers": "nba-portland-trail-blazers-logo.png",
91
- "Sacramento Kings": "nba-sacramento-kings-logo.png",
92
- "San Antonio Spurs": "nba-san-antonio-spurs-logo.png",
93
- "Toronto Raptors": "nba-toronto-raptors-logo-2020.png",
94
- "Utah Jazz": "nba-utah-jazz-logo.png",
95
- "Washington Wizards": "nba-washington-wizards-logo.png",
96
  }
97
 
98
  # Caching player data and model
@@ -104,14 +102,15 @@ def load_player_data():
104
  def load_rf_model():
105
  return joblib.load("rf_injury_change_model.pkl")
106
 
107
- # Main application
108
  # Main application
109
  def main():
110
  st.title("NBA Player Performance Predictor πŸ€")
111
- st.markdown("""
112
- Use this tool to predict how a player's performance metrics might change
113
- if they experience a hypothetical injury.
114
- """)
 
 
115
 
116
  # Load data and model
117
  player_data = load_player_data()
@@ -152,103 +151,32 @@ def main():
152
  "injury_type": injury_type,
153
  **default_stats
154
  }])
155
- input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  st.header("Prediction Results")
158
- if st.button("Predict"):
159
- predictions = rf_model.predict(input_data)
160
- predictions = [round(float(pred), 2) for pred in predictions]
161
-
162
- # Display prediction results
163
- prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
164
- result_df = pd.DataFrame([predictions], columns=prediction_columns)
165
- st.table(result_df)
166
-
167
- # Main content layout
168
- st.divider()
169
- st.header("Player Overview")
170
- col1, col2 = st.columns([1, 2])
171
-
172
- with col1:
173
- st.subheader("Player Details")
174
- st.metric("Age", default_stats['age'])
175
- st.metric("Height (cm)", default_stats['player_height'])
176
- st.metric("Weight (kg)", default_stats['player_weight'])
177
-
178
- with col2:
179
- # Display team logo
180
- if team_name in team_logo_paths:
181
- logo_path = team_logo_paths[team_name]
182
- try:
183
- logo_image = Image.open(logo_path)
184
- st.image(logo_image, caption=f"{team_name} Logo", use_container_width=True)
185
- except FileNotFoundError:
186
- st.error(f"Logo for {team_name} not found.")
187
-
188
-
189
- # Graphs for PPG, AST, and REB
190
- st.divider()
191
- st.header("Player Performance Graphs")
192
-
193
- if st.button("Show Performance Graphs"):
194
- # Filter data for the selected player
195
- player_data_filtered = player_data[player_data["player_name"] == player_name].sort_values(by="season")
196
-
197
- # Ensure all seasons are included
198
- all_seasons = pd.Series(range(player_data["season"].min(), player_data["season"].max() + 1))
199
- player_data_filtered = (
200
- pd.DataFrame({"season": all_seasons})
201
- .merge(player_data_filtered, on="season", how="left")
202
- .fillna({"pts": 0, "ast": 0, "reb": 0}) # Fill missing values
203
- )
204
-
205
- if not player_data_filtered.empty:
206
- # PPG Graph
207
- fig_ppg = px.line(
208
- player_data_filtered,
209
- x="season",
210
- y="pts",
211
- title=f"{player_name}: Points Per Game (PPG) Over Seasons",
212
- labels={"pts": "Points Per Game (PPG)", "season": "Season"},
213
- markers=True
214
- )
215
- fig_ppg.update_layout(template="plotly_white")
216
-
217
- # AST Graph
218
- fig_ast = px.line(
219
- player_data_filtered,
220
- x="season",
221
- y="ast",
222
- title=f"{player_name}: Assists Per Game (AST) Over Seasons",
223
- labels={"ast": "Assists Per Game (AST)", "season": "Season"},
224
- markers=True
225
- )
226
- fig_ast.update_layout(template="plotly_white")
227
-
228
- # REB Graph
229
- fig_reb = px.line(
230
- player_data_filtered,
231
- x="season",
232
- y="reb",
233
- title=f"{player_name}: Rebounds Per Game (REB) Over Seasons",
234
- labels={"reb": "Rebounds Per Game (REB)", "season": "Season"},
235
- markers=True
236
- )
237
- fig_reb.update_layout(template="plotly_white")
238
-
239
- # Display graphs
240
- st.plotly_chart(fig_ppg, use_container_width=True)
241
- st.plotly_chart(fig_ast, use_container_width=True)
242
- st.plotly_chart(fig_reb, use_container_width=True)
243
- else:
244
- st.error("No data available for the selected player.")
245
-
246
- # Footer
247
- st.divider()
248
- st.markdown("""
249
- ### About This Tool
250
- This application predicts how injuries might impact an NBA player's performance using machine learning models. Data is based on historical player stats and injuries.
251
- """)
252
 
253
  if __name__ == "__main__":
254
  main()
 
1
  import streamlit as st
2
  import pandas as pd
3
  import joblib
 
 
 
4
  import plotly.express as px
5
+ from PIL import Image
6
 
7
  # Set the page configuration
8
  st.set_page_config(
9
+ page_title="NBA Player Performance Predictor πŸ€",
10
  page_icon="πŸ€",
11
  layout="centered"
12
  )
 
61
  }
62
 
63
  team_logo_paths = {
64
+ "Cleveland Cavaliers": "NBA_LOGOs/Clevelan-Cavaliers-logo-2022.png",
65
+ "Atlanta Hawks": "NBA_LOGOs/nba-atlanta-hawks-logo.png",
66
+ "Boston Celtics": "NBA_LOGOs/nba-boston-celtics-logo.png",
67
+ "Brooklyn Nets": "NBA_LOGOs/nba-brooklyn-nets-logo.png",
68
+ "Charlotte Hornets": "NBA_LOGOs/nba-charlotte-hornets-logo.png",
69
+ "Chicago Bulls": "NBA_LOGOs/nba-chicago-bulls-logo.png",
70
+ "Dallas Mavericks": "NBA_LOGOs/nba-dallas-mavericks-logo.png",
71
+ "Denver Nuggets": "NBA_LOGOs/nba-denver-nuggets-logo-2018.png",
72
+ "Detroit Pistons": "NBA_LOGOs/nba-detroit-pistons-logo.png",
73
+ "Golden State Warriors": "NBA_LOGOs/nba-golden-state-warriors-logo-2020.png",
74
+ "Houston Rockets": "NBA_LOGOs/nba-houston-rockets-logo-2020.png",
75
+ "Indiana Pacers": "NBA_LOGOs/nba-indiana-pacers-logo.png",
76
+ "LA Clippers": "NBA_LOGOs/nba-la-clippers-logo.png",
77
+ "Los Angeles Lakers": "NBA_LOGOs/nba-los-angeles-lakers-logo.png",
78
+ "Memphis Grizzlies": "NBA_LOGOs/nba-memphis-grizzlies-logo.png",
79
+ "Miami Heat": "NBA_LOGOs/nba-miami-heat-logo.png",
80
+ "Milwaukee Bucks": "NBA_LOGOs/nba-milwaukee-bucks-logo.png",
81
+ "Minnesota Timberwolves": "NBA_LOGOs/nba-minnesota-timberwolves-logo.png",
82
+ "New Orleans Pelicans": "NBA_LOGOs/nba-new-orleans-pelicans-logo.png",
83
+ "New York Knicks": "NBA_LOGOs/nba-new-york-knicks-logo.png",
84
+ "Oklahoma City Thunder": "NBA_LOGOs/nba-oklahoma-city-thunder-logo.png",
85
+ "Orlando Magic": "NBA_LOGOs/nba-orlando-magic-logo.png",
86
+ "Philadelphia 76ers": "NBA_LOGOs/nba-philadelphia-76ers-logo.png",
87
+ "Phoenix Suns": "NBA_LOGOs/nba-phoenix-suns-logo.png",
88
+ "Portland Trail Blazers": "NBA_LOGOs/nba-portland-trail-blazers-logo.png",
89
+ "Sacramento Kings": "NBA_LOGOs/nba-sacramento-kings-logo.png",
90
+ "San Antonio Spurs": "NBA_LOGOs/nba-san-antonio-spurs-logo.png",
91
+ "Toronto Raptors": "NBA_LOGOs/nba-toronto-raptors-logo-2020.png",
92
+ "Utah Jazz": "NBA_LOGOs/nba-utah-jazz-logo.png",
93
+ "Washington Wizards": "NBA_LOGOs/nba-washington-wizards-logo.png",
94
  }
95
 
96
  # Caching player data and model
 
102
  def load_rf_model():
103
  return joblib.load("rf_injury_change_model.pkl")
104
 
 
105
  # Main application
106
  def main():
107
  st.title("NBA Player Performance Predictor πŸ€")
108
+ st.markdown(
109
+ """
110
+ Use this tool to predict how a player's performance metrics might change
111
+ if they experience a hypothetical injury.
112
+ """
113
+ )
114
 
115
  # Load data and model
116
  player_data = load_player_data()
 
151
  "injury_type": injury_type,
152
  **default_stats
153
  }])
 
154
 
155
+ # One-hot encode injury type
156
+ input_data = pd.get_dummies(input_data, columns=["injury_type"], drop_first=True)
157
+
158
+ # Align with model's feature names
159
+ expected_features = rf_model.feature_names_in_
160
+ for feature in expected_features:
161
+ if feature not in input_data.columns:
162
+ input_data[feature] = 0
163
+
164
+ # Ensure columns are in the same order as the model's feature names
165
+ input_data = input_data[expected_features]
166
+
167
+ # Predict and display results
168
  st.header("Prediction Results")
169
+ if st.sidebar.button("Predict"):
170
+ try:
171
+ predictions = rf_model.predict(input_data)
172
+ prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
173
+ st.subheader("Predicted Post-Injury Performance")
174
+ st.write("Based on the inputs, here are the predicted metrics:")
175
+ st.table(pd.DataFrame(predictions, columns=prediction_columns))
176
+ except FileNotFoundError:
177
+ st.error("Model file not found.")
178
+ except ValueError as e:
179
+ st.error(f"Error during prediction: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  if __name__ == "__main__":
182
  main()