lrschuman17 commited on
Commit
559d270
·
verified ·
1 Parent(s): 7758341

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -158
app.py CHANGED
@@ -15,6 +15,86 @@ st.set_page_config(
15
  layout="centered"
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  team_logo_paths = {
20
  "Cleveland Cavaliers": "Clevelan-Cavaliers-logo-2022.png",
@@ -130,181 +210,98 @@ def load_rf_model():
130
 
131
  # Main Streamlit app
132
  def main():
133
- st.title("NBA Player Performance Predictor")
134
  st.write(
135
  """
136
- Predict how a player's performance metrics (e.g., points, rebounds, assists) might change
137
- if a hypothetical injury occurs, based on their position and other factors.
138
  """
139
  )
140
 
141
- # Load player data
142
  player_data = load_player_data()
143
  rf_model = load_rf_model()
144
 
145
  # Sidebar inputs
146
- st.sidebar.header("Player and Injury Input")
147
-
148
- # Dropdown for player selection
 
 
 
 
 
149
  player_list = sorted(player_data['player_name'].dropna().unique())
150
  player_name = st.sidebar.selectbox("Select Player", player_list)
151
 
152
  if player_name:
153
- # Retrieve player details
154
  player_row = player_data[player_data['player_name'] == player_name]
155
  team_name = player_row.iloc[0]['team_abbreviation']
156
  position = player_row.iloc[0]['position']
157
 
158
- if not player_row.empty:
159
- position = player_row.iloc[0]['position']
160
- position_numeric = position_mapping.get(position, 0)
161
-
162
- st.sidebar.write(f"**Position**: {position} (Numeric: {position_numeric})")
163
-
164
- # Default values for features
165
- stats_columns = ['age', 'player_height', 'player_weight']
166
- default_stats = {
167
- stat: player_row.iloc[0][stat] if stat in player_row.columns else 0
168
- for stat in stats_columns
169
- }
170
-
171
- # Allow manual adjustment of stats
172
- for stat in default_stats.keys():
173
- default_stats[stat] = st.sidebar.number_input(f"{stat}", value=default_stats[stat])
174
-
175
- # Injury details
176
- injury_type = st.sidebar.selectbox("Select Hypothetical Injury", injury_types)
177
- # Replace slider with default average based on injury type
178
- default_days_injured = average_days_injured[injury_type] or 30 # Use 30 if `None`
179
- days_injured = st.sidebar.slider(
180
- "Estimated Days Injured",
181
- 0,
182
- 365,
183
- int(default_days_injured),
184
- help=f"Default days for {injury_type}: {int(default_days_injured) if default_days_injured else 'N/A'}"
185
- )
186
- injury_occurrences = st.sidebar.number_input("Injury Occurrences", min_value=0, value=1)
187
-
188
- # Prepare input data
189
- input_data = pd.DataFrame([{
190
- "days_injured": days_injured,
191
- "injury_occurrences": injury_occurrences,
192
- "position": position_numeric,
193
- "injury_type": injury_type, # Include the selected injury type
194
- **default_stats
195
- }])
196
-
197
- # Encode injury type
198
- input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
199
-
200
- # Load Random Forest model
201
- try:
202
- rf_model = load_rf_model()
203
-
204
- # Align input data with the model's feature names
205
- expected_features = rf_model.feature_names_in_
206
- input_data = input_data.reindex(columns=rf_model.feature_names_in_, fill_value=0)
207
-
208
- # Predict and display results
209
- if st.sidebar.button("Predict"):
210
- predictions = rf_model.predict(input_data)
211
- prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change inAST"]
212
- st.subheader("Predicted Post-Injury Performance")
213
- st.write("Based on the inputs, here are the predicted metrics:")
214
- st.table(pd.DataFrame(predictions, columns=prediction_columns))
215
- except FileNotFoundError:
216
- st.error("Model file not found.")
217
- except ValueError as e:
218
- st.error(f"Error during prediction: {e}")
219
-
220
- else:
221
- st.sidebar.error("Player details not found in the dataset.")
222
- else:
223
- st.sidebar.error("Please select a player to view details.")
224
-
225
- st.divider()
226
- st.header("Player Overview")
227
- col1, col2 = st.columns([1, 2])
228
-
229
- with col1:
230
- st.subheader("Player Details")
231
- st.metric("Age", default_stats['age'])
232
- st.metric("Height (cm)", default_stats['player_height'])
233
- st.metric("Weight (kg)", default_stats['player_weight'])
234
-
235
- with col2:
236
- # Display team logo
237
- if team_name in team_logo_paths:
238
- logo_path = team_logo_paths[team_name]
239
- try:
240
- logo_image = Image.open(logo_path)
241
- st.image(logo_image, caption=f"{team_name} Logo", use_container_width=True)
242
- except FileNotFoundError:
243
- st.error(f"Logo for {team_name} not found.")
244
-
245
-
246
- # Graphs for PPG, AST, and REB
247
- st.divider()
248
- st.header("Player Performance Graphs")
249
-
250
- if st.button("Show Performance Graphs"):
251
- # Filter data for the selected player
252
- player_data_filtered = player_data[player_data["player_name"] == player_name].sort_values(by="season")
253
-
254
- # Ensure all seasons are included
255
- all_seasons = pd.Series(range(player_data["season"].min(), player_data["season"].max() + 1))
256
- player_data_filtered = (
257
- pd.DataFrame({"season": all_seasons})
258
- .merge(player_data_filtered, on="season", how="left")
259
- )
260
-
261
- if not player_data_filtered.empty:
262
- # PPG Graph
263
- fig_ppg = px.line(
264
- player_data_filtered,
265
- x="season",
266
- y="pts",
267
- title=f"{player_name}: Points Per Game (PPG) Over Seasons",
268
- labels={"pts": "Points Per Game (PPG)", "season": "Season"},
269
- markers=True
270
- )
271
- fig_ppg.update_layout(template="plotly_white")
272
-
273
- # AST Graph
274
- fig_ast = px.line(
275
- player_data_filtered,
276
- x="season",
277
- y="ast",
278
- title=f"{player_name}: Assists Per Game (AST) Over Seasons",
279
- labels={"ast": "Assists Per Game (AST)", "season": "Season"},
280
- markers=True
281
- )
282
- fig_ast.update_layout(template="plotly_white")
283
-
284
- # REB Graph
285
- fig_reb = px.line(
286
- player_data_filtered,
287
- x="season",
288
- y="reb",
289
- title=f"{player_name}: Rebounds Per Game (REB) Over Seasons",
290
- labels={"reb": "Rebounds Per Game (REB)", "season": "Season"},
291
- markers=True
292
- )
293
- fig_reb.update_layout(template="plotly_white")
294
-
295
- # Display graphs
296
- st.plotly_chart(fig_ppg, use_container_width=True)
297
- st.plotly_chart(fig_ast, use_container_width=True)
298
- st.plotly_chart(fig_reb, use_container_width=True)
299
- else:
300
- st.error("No data available for the selected player.")
301
-
302
- # Footer
303
- st.divider()
304
- st.markdown("""
305
- ### About This Tool
306
- This application predicts how injuries might impact an NBA player's performance using machine learning models. Data is based on historical player stats and injuries.
307
- """)
308
 
309
  if __name__ == "__main__":
310
  main()
 
15
  layout="centered"
16
  )
17
 
18
+ # Custom CSS for vibrant NBA sidebar header
19
+ st.markdown(
20
+ """
21
+ <style>
22
+ body {
23
+ background: linear-gradient(to bottom, #0033a0, #ed174c); /* NBA team colors gradient */
24
+ font-family: 'Trebuchet MS', sans-serif;
25
+ margin: 0;
26
+ padding: 0;
27
+ color: white; /* Set text color to white */
28
+ }
29
+ .sidebar .sidebar-content {
30
+ background: linear-gradient(to bottom, #4B0082, #1E90FF); /* Purple to blue gradient */
31
+ border-radius: 10px;
32
+ padding: 10px;
33
+ color: #ffffff; /* Set sidebar text color to white */
34
+ }
35
+ .sidebar h2 {
36
+ background: linear-gradient(to right, #FF1493, #FF4500); /* Pink to red gradient */
37
+ color: white; /* Text color */
38
+ padding: 10px;
39
+ text-align: center;
40
+ font-size: 20px;
41
+ font-weight: bold;
42
+ border-radius: 5px;
43
+ text-shadow: 2px 2px #000000; /* Add shadow for better visibility */
44
+ margin-bottom: 15px;
45
+ }
46
+ .stButton > button {
47
+ background-color: #ffcc00; /* Bold yellow */
48
+ color: #0033a0; /* Button text color */
49
+ border: none;
50
+ border-radius: 5px;
51
+ padding: 10px 15px;
52
+ font-size: 16px;
53
+ transition: background-color 0.3s ease;
54
+ }
55
+ .stButton > button:hover {
56
+ background-color: #ffc107; /* Brighter yellow */
57
+ box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.2);
58
+ }
59
+ .stMarkdown h1, .stMarkdown h2, .stMarkdown h3 {
60
+ color: #ffffff; /* Set headings color to white */
61
+ text-shadow: 2px 2px #000000; /* Add shadow for better visibility */
62
+ }
63
+ .block-container {
64
+ border-radius: 10px;
65
+ padding: 20px;
66
+ background-color: rgba(0, 0, 0, 0.8); /* Dark semi-transparent background */
67
+ color: #ffffff; /* Ensure text inside the container is white */
68
+ }
69
+ .dataframe {
70
+ background-color: rgba(255, 255, 255, 0.1); /* Transparent table background */
71
+ color: #ffffff; /* Table text color */
72
+ border-radius: 10px;
73
+ }
74
+ .stPlotlyChart {
75
+ background-color: rgba(0, 0, 0, 0.8); /* Match dark theme */
76
+ padding: 10px;
77
+ border-radius: 10px;
78
+ box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5);
79
+ }
80
+ .styled-table {
81
+ width: 100%;
82
+ border-collapse: collapse;
83
+ margin: 25px 0;
84
+ font-size: 18px;
85
+ text-align: left;
86
+ border-radius: 5px 5px 0 0;
87
+ overflow: hidden;
88
+ color: #ffffff; /* Table text color */
89
+ }
90
+ .styled-table th, .styled-table td {
91
+ padding: 12px 15px;
92
+ }
93
+ </style>
94
+ """,
95
+ unsafe_allow_html=True
96
+ )
97
+
98
 
99
  team_logo_paths = {
100
  "Cleveland Cavaliers": "Clevelan-Cavaliers-logo-2022.png",
 
210
 
211
  # Main Streamlit app
212
  def main():
213
+ st.title("NBA Player Performance Predictor 🏀")
214
  st.write(
215
  """
216
+ Welcome to the **NBA Player Performance Predictor**! This app helps predict changes in a player's performance metrics
217
+ after experiencing a hypothetical injury. Simply input the details and see the magic happen!
218
  """
219
  )
220
 
221
+ # Load player data and model
222
  player_data = load_player_data()
223
  rf_model = load_rf_model()
224
 
225
  # Sidebar inputs
226
+ st.sidebar.markdown(
227
+ """
228
+ <div style="padding: 10px; background: linear-gradient(to right, #6a11cb, #2575fc); color: white; border-radius: 10px;">
229
+ <h3>Player Details</h3>
230
+ </div>
231
+ """,
232
+ unsafe_allow_html=True
233
+ )
234
  player_list = sorted(player_data['player_name'].dropna().unique())
235
  player_name = st.sidebar.selectbox("Select Player", player_list)
236
 
237
  if player_name:
 
238
  player_row = player_data[player_data['player_name'] == player_name]
239
  team_name = player_row.iloc[0]['team_abbreviation']
240
  position = player_row.iloc[0]['position']
241
 
242
+ stats_columns = ['age', 'player_height', 'player_weight']
243
+ default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
244
+
245
+ for stat in default_stats.keys():
246
+ default_stats[stat] = st.sidebar.number_input(f"{stat}", value=default_stats[stat])
247
+
248
+ injury_type = st.sidebar.selectbox("Select Hypothetical Injury", injury_types)
249
+ default_days_injured = average_days_injured.get(injury_type, 30)
250
+ days_injured = st.sidebar.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
251
+ injury_occurrences = st.sidebar.number_input("Injury Occurrences", min_value=0, value=1)
252
+
253
+ input_data = pd.DataFrame([{
254
+ "days_injured": days_injured,
255
+ "injury_occurrences": injury_occurrences,
256
+ "position": position_mapping.get(position, 0),
257
+ "injury_type": injury_type,
258
+ **default_stats
259
+ }])
260
+ input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
261
+
262
+ st.divider()
263
+ st.header("Player Overview")
264
+ col1, col2 = st.columns([1, 2])
265
+
266
+ with col1:
267
+ st.subheader("Player Details")
268
+ st.metric("Age", default_stats['age'])
269
+ st.metric("Height (cm)", default_stats['player_height'])
270
+ st.metric("Weight (kg)", default_stats['player_weight'])
271
+
272
+ with col2:
273
+ if team_name in team_logo_paths:
274
+ logo_path = team_logo_paths[team_name]
275
+ try:
276
+ logo_image = Image.open(logo_path)
277
+ st.image(logo_image, caption=f"{team_name} Logo", use_column_width=True)
278
+ except FileNotFoundError:
279
+ st.error(f"Logo for {team_name} not found.")
280
+
281
+ if st.sidebar.button("Predict 🔮"):
282
+ predictions = rf_model.predict(input_data)
283
+ prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
284
+ prediction_df = pd.DataFrame(predictions, columns=prediction_columns)
285
+
286
+ st.subheader("Predicted Post-Injury Performance")
287
+ st.write(prediction_df)
288
+
289
+ fig = go.Figure()
290
+ for col in prediction_columns:
291
+ fig.add_trace(go.Bar(
292
+ x=[col],
293
+ y=prediction_df[col],
294
+ name=col,
295
+ marker=dict(color=px.colors.qualitative.Plotly[prediction_columns.index(col)])
296
+ ))
297
+
298
+ fig.update_layout(
299
+ title="Predicted Performance Changes",
300
+ xaxis_title="Metrics",
301
+ yaxis_title="Change Value",
302
+ template="plotly_dark"
303
+ )
304
+ st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  if __name__ == "__main__":
307
  main()