lrschuman17 commited on
Commit
dd5f0c1
Β·
verified Β·
1 Parent(s): a2f85fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -105
app.py CHANGED
@@ -1,37 +1,48 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import joblib
 
4
  import plotly.express as px
5
- from PIL import Image
6
-
7
- # Set the page configuration
8
- st.set_page_config(
9
- page_title="NBA Player Performance Predictor πŸ€",
10
- page_icon="πŸ€",
11
- layout="centered"
12
- )
13
 
14
  # Mapping for position to numeric values
15
  position_mapping = {
16
- "PG": 1.0,
17
- "SG": 2.0,
18
- "SF": 3.0,
19
- "PF": 4.0,
20
- "C": 5.0,
21
  }
22
 
23
- # Injury types and average days
24
  injury_types = [
25
- "foot fracture injury", "hip flexor surgery injury", "calf strain injury",
26
- "quad injury injury", "shoulder sprain injury", "foot sprain injury",
27
- "torn rotator cuff injury injury", "torn mcl injury", "hip flexor strain injury",
28
- "fractured leg injury", "sprained mcl injury", "ankle sprain injury",
29
- "hamstring injury injury", "meniscus tear injury", "torn hamstring injury",
30
- "dislocated shoulder injury", "ankle fracture injury", "fractured hand injury",
31
- "bone spurs injury", "acl tear injury", "hip labrum injury", "back surgery injury",
32
- "arm injury injury", "torn shoulder labrum injury", "lower back spasm injury"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ]
34
 
 
35
  average_days_injured = {
36
  "foot fracture injury": 207.666667,
37
  "hip flexor surgery injury": 256.000000,
@@ -60,123 +71,111 @@ average_days_injured = {
60
  "lower back spasm injury": 234.000000,
61
  }
62
 
63
- team_logo_paths = {
64
- "Cleveland Cavaliers": "NBA_LOGOs/Clevelan-Cavaliers-logo-2022.png",
65
- "Atlanta Hawks": "NBA_LOGOs/nba-atlanta-hawks-logo.png",
66
- "Boston Celtics": "NBA_LOGOs/nba-boston-celtics-logo.png",
67
- "Brooklyn Nets": "NBA_LOGOs/nba-brooklyn-nets-logo.png",
68
- "Charlotte Hornets": "NBA_LOGOs/nba-charlotte-hornets-logo.png",
69
- "Chicago Bulls": "NBA_LOGOs/nba-chicago-bulls-logo.png",
70
- "Dallas Mavericks": "NBA_LOGOs/nba-dallas-mavericks-logo.png",
71
- "Denver Nuggets": "NBA_LOGOs/nba-denver-nuggets-logo-2018.png",
72
- "Detroit Pistons": "NBA_LOGOs/nba-detroit-pistons-logo.png",
73
- "Golden State Warriors": "NBA_LOGOs/nba-golden-state-warriors-logo-2020.png",
74
- "Houston Rockets": "NBA_LOGOs/nba-houston-rockets-logo-2020.png",
75
- "Indiana Pacers": "NBA_LOGOs/nba-indiana-pacers-logo.png",
76
- "LA Clippers": "NBA_LOGOs/nba-la-clippers-logo.png",
77
- "Los Angeles Lakers": "NBA_LOGOs/nba-los-angeles-lakers-logo.png",
78
- "Memphis Grizzlies": "NBA_LOGOs/nba-memphis-grizzlies-logo.png",
79
- "Miami Heat": "NBA_LOGOs/nba-miami-heat-logo.png",
80
- "Milwaukee Bucks": "NBA_LOGOs/nba-milwaukee-bucks-logo.png",
81
- "Minnesota Timberwolves": "NBA_LOGOs/nba-minnesota-timberwolves-logo.png",
82
- "New Orleans Pelicans": "NBA_LOGOs/nba-new-orleans-pelicans-logo.png",
83
- "New York Knicks": "NBA_LOGOs/nba-new-york-knicks-logo.png",
84
- "Oklahoma City Thunder": "NBA_LOGOs/nba-oklahoma-city-thunder-logo.png",
85
- "Orlando Magic": "NBA_LOGOs/nba-orlando-magic-logo.png",
86
- "Philadelphia 76ers": "NBA_LOGOs/nba-philadelphia-76ers-logo.png",
87
- "Phoenix Suns": "NBA_LOGOs/nba-phoenix-suns-logo.png",
88
- "Portland Trail Blazers": "NBA_LOGOs/nba-portland-trail-blazers-logo.png",
89
- "Sacramento Kings": "NBA_LOGOs/nba-sacramento-kings-logo.png",
90
- "San Antonio Spurs": "NBA_LOGOs/nba-san-antonio-spurs-logo.png",
91
- "Toronto Raptors": "NBA_LOGOs/nba-toronto-raptors-logo-2020.png",
92
- "Utah Jazz": "NBA_LOGOs/nba-utah-jazz-logo.png",
93
- "Washington Wizards": "NBA_LOGOs/nba-washington-wizards-logo.png",
94
- }
95
 
96
- # Caching player data and model
 
 
 
97
  @st.cache_resource
98
  def load_player_data():
99
- return pd.read_csv("player_data.csv")
100
 
 
101
  @st.cache_resource
102
  def load_rf_model():
103
- return joblib.load("rf_injury_change_model.pkl")
104
 
105
- # Main application
106
  def main():
107
- st.title("NBA Player Performance Predictor πŸ€")
108
- st.markdown(
109
  """
110
- Use this tool to predict how a player's performance metrics might change
111
- if they experience a hypothetical injury.
112
  """
113
  )
114
 
115
- # Load data and model
116
  player_data = load_player_data()
117
  rf_model = load_rf_model()
118
 
119
  # Sidebar inputs
120
- with st.sidebar:
121
- st.header("Player & Injury Inputs")
122
- player_list = sorted(player_data['player_name'].dropna().unique())
123
- player_name = st.selectbox("Select Player", player_list)
124
-
125
- if player_name:
126
- # Filter data for the selected player
127
- player_row = player_data[player_data['player_name'] == player_name]
128
- team_name = player_row.iloc[0]['team_abbreviation']
 
 
129
  position = player_row.iloc[0]['position']
130
- stats_columns = ['age', 'player_height', 'player_weight']
 
 
131
 
132
- st.write(f"**Position**: {position}")
133
- st.write(f"**Team**: {team_name}")
 
 
 
 
134
 
135
- # Player stats
136
- default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
137
  for stat in default_stats.keys():
138
- default_stats[stat] = st.number_input(f"{stat}", value=default_stats[stat])
139
 
140
  # Injury details
141
- injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
142
- default_days_injured = average_days_injured.get(injury_type, 30)
143
- days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
144
- injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)
145
-
146
- # Prepare data for prediction
 
 
 
 
 
 
 
147
  input_data = pd.DataFrame([{
148
  "days_injured": days_injured,
149
  "injury_occurrences": injury_occurrences,
150
- "position": position_mapping.get(position, 0),
151
- "injury_type": injury_type,
152
  **default_stats
153
  }])
154
 
155
- # One-hot encode injury type
156
- input_data = pd.get_dummies(input_data, columns=["injury_type"], drop_first=True)
157
 
158
- # Align with model's feature names
159
- expected_features = rf_model.feature_names_in_
160
- for feature in expected_features:
161
- if feature not in input_data.columns:
162
- input_data[feature] = 0
163
 
164
- # Ensure columns are in the same order as the model's feature names
165
- input_data = input_data[expected_features]
 
166
 
167
- # Predict and display results
168
- st.header("Prediction Results")
169
- if st.sidebar.button("Predict"):
170
- try:
171
  predictions = rf_model.predict(input_data)
172
- prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
173
  st.subheader("Predicted Post-Injury Performance")
174
  st.write("Based on the inputs, here are the predicted metrics:")
175
  st.table(pd.DataFrame(predictions, columns=prediction_columns))
176
- except FileNotFoundError:
177
- st.error("Model file not found.")
178
- except ValueError as e:
179
- st.error(f"Error during prediction: {e}")
 
 
 
 
 
180
 
181
  if __name__ == "__main__":
182
- main()
 
1
  import streamlit as st
2
  import pandas as pd
3
  import joblib
4
+ from sklearn.ensemble import RandomForestRegressor
5
  import plotly.express as px
 
 
 
 
 
 
 
 
6
 
7
  # Mapping for position to numeric values
8
  position_mapping = {
9
+ "PG": 1.0, # Point Guard
10
+ "SG": 2.0, # Shooting Guard
11
+ "SF": 3.0, # Small Forward
12
+ "PF": 4.0, # Power Forward
13
+ "C": 5.0, # Center
14
  }
15
 
16
+ # Predefined injury types
17
  injury_types = [
18
+ "foot fracture injury",
19
+ "hip flexor surgery injury",
20
+ "calf strain injury",
21
+ "quad injury injury",
22
+ "shoulder sprain injury",
23
+ "foot sprain injury",
24
+ "torn rotator cuff injury injury",
25
+ "torn mcl injury",
26
+ "hip flexor strain injury",
27
+ "fractured leg injury",
28
+ "sprained mcl injury",
29
+ "ankle sprain injury",
30
+ "hamstring injury injury",
31
+ "meniscus tear injury",
32
+ "torn hamstring injury",
33
+ "dislocated shoulder injury",
34
+ "ankle fracture injury",
35
+ "fractured hand injury",
36
+ "bone spurs injury",
37
+ "acl tear injury",
38
+ "hip labrum injury",
39
+ "back surgery injury",
40
+ "arm injury injury",
41
+ "torn shoulder labrum injury",
42
+ "lower back spasm injury"
43
  ]
44
 
45
+ # Injury average days dictionary
46
  average_days_injured = {
47
  "foot fracture injury": 207.666667,
48
  "hip flexor surgery injury": 256.000000,
 
71
  "lower back spasm injury": 234.000000,
72
  }
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+
76
+
77
+
78
+ # Load player dataset
79
  @st.cache_resource
80
  def load_player_data():
81
+ return pd.read_csv("/Users/laraschuman/Desktop/CTP-Project/player_data.csv")
82
 
83
+ # Load Random Forest model
84
  @st.cache_resource
85
  def load_rf_model():
86
+ return joblib.load("/Users/laraschuman/Desktop/CTP-Project/rf_injury_change_model.pkl")
87
 
88
+ # Main Streamlit app
89
  def main():
90
+ st.title("NBA Player Performance Predictor")
91
+ st.write(
92
  """
93
+ Predict how a player's performance metrics (e.g., points, rebounds, assists) might change
94
+ if a hypothetical injury occurs, based on their position and other factors.
95
  """
96
  )
97
 
98
+ # Load player data
99
  player_data = load_player_data()
100
  rf_model = load_rf_model()
101
 
102
  # Sidebar inputs
103
+ st.sidebar.header("Player and Injury Input")
104
+
105
+ # Dropdown for player selection
106
+ player_list = sorted(player_data['player_name'].dropna().unique())
107
+ player_name = st.sidebar.selectbox("Select Player", player_list)
108
+
109
+ if player_name:
110
+ # Retrieve player details
111
+ player_row = player_data[player_data['player_name'] == player_name]
112
+
113
+ if not player_row.empty:
114
  position = player_row.iloc[0]['position']
115
+ position_numeric = position_mapping.get(position, 0)
116
+
117
+ st.sidebar.write(f"**Position**: {position} (Numeric: {position_numeric})")
118
 
119
+ # Default values for features
120
+ stats_columns = ['age', 'player_height', 'player_weight']
121
+ default_stats = {
122
+ stat: player_row.iloc[0][stat] if stat in player_row.columns else 0
123
+ for stat in stats_columns
124
+ }
125
 
126
+ # Allow manual adjustment of stats
 
127
  for stat in default_stats.keys():
128
+ default_stats[stat] = st.sidebar.number_input(f"{stat}", value=default_stats[stat])
129
 
130
  # Injury details
131
+ injury_type = st.sidebar.selectbox("Select Hypothetical Injury", injury_types)
132
+ # Replace slider with default average based on injury type
133
+ default_days_injured = average_days_injured[injury_type] or 30 # Use 30 if `None`
134
+ days_injured = st.sidebar.slider(
135
+ "Estimated Days Injured",
136
+ 0,
137
+ 365,
138
+ int(default_days_injured),
139
+ help=f"Default days for {injury_type}: {int(default_days_injured) if default_days_injured else 'N/A'}"
140
+ )
141
+ injury_occurrences = st.sidebar.number_input("Injury Occurrences", min_value=0, value=1)
142
+
143
+ # Prepare input data
144
  input_data = pd.DataFrame([{
145
  "days_injured": days_injured,
146
  "injury_occurrences": injury_occurrences,
147
+ "position": position_numeric,
148
+ "injury_type": injury_type, # Include the selected injury type
149
  **default_stats
150
  }])
151
 
152
+ # Encode injury type
153
+ input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
154
 
155
+ # Load Random Forest model
156
+ try:
157
+ rf_model = load_rf_model()
 
 
158
 
159
+ # Align input data with the model's feature names
160
+ expected_features = rf_model.feature_names_in_
161
+ input_data = input_data.reindex(columns=rf_model.feature_names_in_, fill_value=0)
162
 
163
+ # Predict and display results
164
+ if st.sidebar.button("Predict"):
 
 
165
  predictions = rf_model.predict(input_data)
166
+ prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change inAST"]
167
  st.subheader("Predicted Post-Injury Performance")
168
  st.write("Based on the inputs, here are the predicted metrics:")
169
  st.table(pd.DataFrame(predictions, columns=prediction_columns))
170
+ except FileNotFoundError:
171
+ st.error("Model file not found.")
172
+ except ValueError as e:
173
+ st.error(f"Error during prediction: {e}")
174
+
175
+ else:
176
+ st.sidebar.error("Player details not found in the dataset.")
177
+ else:
178
+ st.sidebar.error("Please select a player to view details.")
179
 
180
  if __name__ == "__main__":
181
+ main()