lrschuman17 commited on
Commit
c09e91b
·
verified ·
1 Parent(s): 99c7136

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -37
app.py CHANGED
@@ -2,8 +2,16 @@ import streamlit as st
2
  import pandas as pd
3
  import joblib
4
  from sklearn.ensemble import RandomForestRegressor
5
- import plotly.express as px
6
  import plotly.graph_objects as go
 
 
 
 
 
 
 
 
 
7
 
8
  # Mapping for position to numeric values
9
  position_mapping = {
@@ -14,7 +22,7 @@ position_mapping = {
14
  "C": 5.0,
15
  }
16
 
17
- # Predefined injury types
18
  injury_types = [
19
  "foot fracture injury", "hip flexor surgery injury", "calf strain injury",
20
  "quad injury injury", "shoulder sprain injury", "foot sprain injury",
@@ -28,10 +36,66 @@ injury_types = [
28
 
29
  average_days_injured = {
30
  "foot fracture injury": 207.666667,
31
- # ... (other injuries with their averages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  "lower back spasm injury": 234.000000,
33
  }
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @st.cache_resource
36
  def load_player_data():
37
  return pd.read_csv("player_data.csv")
@@ -40,83 +104,151 @@ def load_player_data():
40
  def load_rf_model():
41
  return joblib.load("rf_injury_change_model.pkl")
42
 
43
- # Main Streamlit app
 
44
  def main():
45
  st.title("NBA Player Performance Predictor 🏀")
46
- st.write("""
47
  Use this tool to predict how a player's performance metrics might change
48
  if they experience a hypothetical injury.
49
  """)
50
 
 
51
  player_data = load_player_data()
52
  rf_model = load_rf_model()
53
 
54
- # Layout: Sidebar for Inputs
55
  with st.sidebar:
56
  st.header("Player & Injury Inputs")
57
  player_list = sorted(player_data['player_name'].dropna().unique())
58
  player_name = st.selectbox("Select Player", player_list)
59
 
60
- # Player Details
61
  if player_name:
 
62
  player_row = player_data[player_data['player_name'] == player_name]
 
63
  position = player_row.iloc[0]['position']
64
- position_numeric = position_mapping.get(position, 0)
65
- st.write(f"**Position**: {position} (Numeric: {position_numeric})")
66
-
67
  stats_columns = ['age', 'player_height', 'player_weight']
68
- default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
69
 
 
 
 
 
 
70
  for stat in default_stats.keys():
71
  default_stats[stat] = st.number_input(f"{stat}", value=default_stats[stat])
72
 
 
73
  injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
74
  default_days_injured = average_days_injured.get(injury_type, 30)
75
  days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
76
  injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)
77
 
 
78
  input_data = pd.DataFrame([{
79
  "days_injured": days_injured,
80
  "injury_occurrences": injury_occurrences,
81
- "position": position_numeric,
82
  "injury_type": injury_type,
83
  **default_stats
84
  }])
85
  input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
86
 
87
- # Main Layout
88
- col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  with col1:
91
  st.subheader("Player Details")
92
  st.metric("Age", default_stats['age'])
93
  st.metric("Height (cm)", default_stats['player_height'])
94
  st.metric("Weight (kg)", default_stats['player_weight'])
95
- st.image("player_placeholder.png", caption=f"{player_name}", width=200) # Placeholder image
96
 
97
  with col2:
98
- st.subheader("Prediction Results")
99
-
100
- if st.button("Predict"):
101
- predictions = rf_model.predict(input_data)
102
- predictions = [round(float(pred), 2) for pred in predictions]
103
-
104
- prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
105
- result_df = pd.DataFrame([predictions], columns=prediction_columns)
106
-
107
- # Bar Chart for Visualizing Predictions
108
- fig = go.Figure(data=[
109
- go.Bar(
110
- x=prediction_columns,
111
- y=predictions,
112
- marker_color=["green" if val > 0 else "red" for val in predictions]
113
- )
114
- ])
115
- fig.update_layout(title="Predicted Performance Changes", xaxis_title="Metrics", yaxis_title="Change")
116
- st.plotly_chart(fig)
117
-
118
- # Display predictions as a table
119
- st.table(result_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  if __name__ == "__main__":
122
  main()
 
2
  import pandas as pd
3
  import joblib
4
  from sklearn.ensemble import RandomForestRegressor
 
5
  import plotly.graph_objects as go
6
+ from PIL import Image
7
+ import plotly.express as px
8
+
9
+ # Set the page configuration
10
+ st.set_page_config(
11
+ page_title="NBA Player Performance Predictor",
12
+ page_icon="🏀",
13
+ layout="centered"
14
+ )
15
 
16
  # Mapping for position to numeric values
17
  position_mapping = {
 
22
  "C": 5.0,
23
  }
24
 
25
+ # Injury types and average days
26
  injury_types = [
27
  "foot fracture injury", "hip flexor surgery injury", "calf strain injury",
28
  "quad injury injury", "shoulder sprain injury", "foot sprain injury",
 
36
 
37
  average_days_injured = {
38
  "foot fracture injury": 207.666667,
39
+ "hip flexor surgery injury": 256.000000,
40
+ "calf strain injury": 236.000000,
41
+ "quad injury injury": 283.000000,
42
+ "shoulder sprain injury": 259.500000,
43
+ "foot sprain injury": 294.000000,
44
+ "torn rotator cuff injury injury": 251.500000,
45
+ "torn mcl injury": 271.000000,
46
+ "hip flexor strain injury": 253.000000,
47
+ "fractured leg injury": 250.250000,
48
+ "sprained mcl injury": 228.666667,
49
+ "ankle sprain injury": 231.333333,
50
+ "hamstring injury injury": 220.000000,
51
+ "meniscus tear injury": 201.250000,
52
+ "torn hamstring injury": 187.666667,
53
+ "dislocated shoulder injury": 269.000000,
54
+ "ankle fracture injury": 114.500000,
55
+ "fractured hand injury": 169.142857,
56
+ "bone spurs injury": 151.500000,
57
+ "acl tear injury": 268.000000,
58
+ "hip labrum injury": 247.500000,
59
+ "back surgery injury": 215.800000,
60
+ "arm injury injury": 303.666667,
61
+ "torn shoulder labrum injury": 195.666667,
62
  "lower back spasm injury": 234.000000,
63
  }
64
 
65
+ team_logo_paths = {
66
+ "Cleveland Cavaliers": "NBA_LOGOs/Clevelan-Cavaliers-logo-2022.png",
67
+ "Atlanta Hawks": "NBA_LOGOs/nba-atlanta-hawks-logo.png",
68
+ "Boston Celtics": "NBA_LOGOs/nba-boston-celtics-logo.png",
69
+ "Brooklyn Nets": "NBA_LOGOs/nba-brooklyn-nets-logo.png",
70
+ "Charlotte Hornets": "NBA_LOGOs/nba-charlotte-hornets-logo.png",
71
+ "Chicago Bulls": "NBA_LOGOs/nba-chicago-bulls-logo.png",
72
+ "Dallas Mavericks": "NBA_LOGOs/nba-dallas-mavericks-logo.png",
73
+ "Denver Nuggets": "NBA_LOGOs/nba-denver-nuggets-logo-2018.png",
74
+ "Detroit Pistons": "NBA_LOGOs/nba-detroit-pistons-logo.png",
75
+ "Golden State Warriors": "NBA_LOGOs/nba-golden-state-warriors-logo-2020.png",
76
+ "Houston Rockets": "NBA_LOGOs/nba-houston-rockets-logo-2020.png",
77
+ "Indiana Pacers": "NBA_LOGOs/nba-indiana-pacers-logo.png",
78
+ "LA Clippers": "NBA_LOGOs/nba-la-clippers-logo.png",
79
+ "Los Angeles Lakers": "NBA_LOGOs/nba-los-angeles-lakers-logo.png",
80
+ "Memphis Grizzlies": "NBA_LOGOs/nba-memphis-grizzlies-logo.png",
81
+ "Miami Heat": "NBA_LOGOs/nba-miami-heat-logo.png",
82
+ "Milwaukee Bucks": "NBA_LOGOs/nba-milwaukee-bucks-logo.png",
83
+ "Minnesota Timberwolves": "NBA_LOGOs/nba-minnesota-timberwolves-logo.png",
84
+ "New Orleans Pelicans": "NBA_LOGOs/nba-new-orleans-pelicans-logo.png",
85
+ "New York Knicks": "NBA_LOGOs/nba-new-york-knicks-logo.png",
86
+ "Oklahoma City Thunder": "NBA_LOGOs/nba-oklahoma-city-thunder-logo.png",
87
+ "Orlando Magic": "NBA_LOGOs/nba-orlando-magic-logo.png",
88
+ "Philadelphia 76ers": "NBA_LOGOs/nba-philadelphia-76ers-logo.png",
89
+ "Phoenix Suns": "NBA_LOGOs/nba-phoenix-suns-logo.png",
90
+ "Portland Trail Blazers": "NBA_LOGOs/nba-portland-trail-blazers-logo.png",
91
+ "Sacramento Kings": "NBA_LOGOs/nba-sacramento-kings-logo.png",
92
+ "San Antonio Spurs": "NBA_LOGOs/nba-san-antonio-spurs-logo.png",
93
+ "Toronto Raptors": "NBA_LOGOs/nba-toronto-raptors-logo-2020.png",
94
+ "Utah Jazz": "NBA_LOGOs/nba-utah-jazz-logo.png",
95
+ "Washington Wizards": "NBA_LOGOs/nba-washington-wizards-logo.png",
96
+ }
97
+
98
+ # Caching player data and model
99
  @st.cache_resource
100
  def load_player_data():
101
  return pd.read_csv("player_data.csv")
 
104
  def load_rf_model():
105
  return joblib.load("rf_injury_change_model.pkl")
106
 
107
+ # Main application
108
+ # Main application
109
  def main():
110
  st.title("NBA Player Performance Predictor 🏀")
111
+ st.markdown("""
112
  Use this tool to predict how a player's performance metrics might change
113
  if they experience a hypothetical injury.
114
  """)
115
 
116
+ # Load data and model
117
  player_data = load_player_data()
118
  rf_model = load_rf_model()
119
 
120
+ # Sidebar inputs
121
  with st.sidebar:
122
  st.header("Player & Injury Inputs")
123
  player_list = sorted(player_data['player_name'].dropna().unique())
124
  player_name = st.selectbox("Select Player", player_list)
125
 
 
126
  if player_name:
127
+ # Filter data for the selected player
128
  player_row = player_data[player_data['player_name'] == player_name]
129
+ team_name = player_row.iloc[0]['team_abbreviation']
130
  position = player_row.iloc[0]['position']
 
 
 
131
  stats_columns = ['age', 'player_height', 'player_weight']
 
132
 
133
+ st.write(f"**Position**: {position}")
134
+ st.write(f"**Team**: {team_name}")
135
+
136
+ # Player stats
137
+ default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
138
  for stat in default_stats.keys():
139
  default_stats[stat] = st.number_input(f"{stat}", value=default_stats[stat])
140
 
141
+ # Injury details
142
  injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
143
  default_days_injured = average_days_injured.get(injury_type, 30)
144
  days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
145
  injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)
146
 
147
+ # Prepare data for prediction
148
  input_data = pd.DataFrame([{
149
  "days_injured": days_injured,
150
  "injury_occurrences": injury_occurrences,
151
+ "position": position_mapping.get(position, 0),
152
  "injury_type": injury_type,
153
  **default_stats
154
  }])
155
  input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
156
 
157
+ st.header("Prediction Results")
158
+ if st.button("Predict"):
159
+ predictions = rf_model.predict(input_data)
160
+ predictions = [round(float(pred), 2) for pred in predictions]
161
+
162
+ # Display prediction results
163
+ prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
164
+ result_df = pd.DataFrame([predictions], columns=prediction_columns)
165
+ st.table(result_df)
166
+
167
+ # Main content layout
168
+ st.divider()
169
+ st.header("Player Overview")
170
+ col1, col2 = st.columns([1, 2])
171
 
172
  with col1:
173
  st.subheader("Player Details")
174
  st.metric("Age", default_stats['age'])
175
  st.metric("Height (cm)", default_stats['player_height'])
176
  st.metric("Weight (kg)", default_stats['player_weight'])
 
177
 
178
  with col2:
179
+ # Display team logo
180
+ if team_name in team_logo_paths:
181
+ logo_path = team_logo_paths[team_name]
182
+ try:
183
+ logo_image = Image.open(logo_path)
184
+ st.image(logo_image, caption=f"{team_name} Logo", use_column_width=True)
185
+ except FileNotFoundError:
186
+ st.error(f"Logo for {team_name} not found.")
187
+
188
+
189
+ # Graphs for PPG, AST, and REB
190
+ st.divider()
191
+ st.header("Player Performance Graphs")
192
+
193
+ if st.button("Show Performance Graphs"):
194
+ # Filter data for the selected player
195
+ player_data_filtered = player_data[player_data["player_name"] == player_name].sort_values(by="season")
196
+
197
+ # Ensure all seasons are included
198
+ all_seasons = pd.Series(range(player_data["season"].min(), player_data["season"].max() + 1))
199
+ player_data_filtered = (
200
+ pd.DataFrame({"season": all_seasons})
201
+ .merge(player_data_filtered, on="season", how="left")
202
+ .fillna({"pts": 0, "ast": 0, "reb": 0}) # Fill missing values
203
+ )
204
+
205
+ if not player_data_filtered.empty:
206
+ # PPG Graph
207
+ fig_ppg = px.line(
208
+ player_data_filtered,
209
+ x="season",
210
+ y="pts",
211
+ title=f"{player_name}: Points Per Game (PPG) Over Seasons",
212
+ labels={"pts": "Points Per Game (PPG)", "season": "Season"},
213
+ markers=True
214
+ )
215
+ fig_ppg.update_layout(template="plotly_white")
216
+
217
+ # AST Graph
218
+ fig_ast = px.line(
219
+ player_data_filtered,
220
+ x="season",
221
+ y="ast",
222
+ title=f"{player_name}: Assists Per Game (AST) Over Seasons",
223
+ labels={"ast": "Assists Per Game (AST)", "season": "Season"},
224
+ markers=True
225
+ )
226
+ fig_ast.update_layout(template="plotly_white")
227
+
228
+ # REB Graph
229
+ fig_reb = px.line(
230
+ player_data_filtered,
231
+ x="season",
232
+ y="reb",
233
+ title=f"{player_name}: Rebounds Per Game (REB) Over Seasons",
234
+ labels={"reb": "Rebounds Per Game (REB)", "season": "Season"},
235
+ markers=True
236
+ )
237
+ fig_reb.update_layout(template="plotly_white")
238
+
239
+ # Display graphs
240
+ st.plotly_chart(fig_ppg, use_container_width=True)
241
+ st.plotly_chart(fig_ast, use_container_width=True)
242
+ st.plotly_chart(fig_reb, use_container_width=True)
243
+ else:
244
+ st.error("No data available for the selected player.")
245
+
246
+ # Footer
247
+ st.divider()
248
+ st.markdown("""
249
+ ### About This Tool
250
+ This application predicts how injuries might impact an NBA player's performance using machine learning models. Data is based on historical player stats and injuries.
251
+ """)
252
 
253
  if __name__ == "__main__":
254
  main()