lrschuman17 commited on
Commit
99c7136
·
verified ·
1 Parent(s): 5663fe1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -133
app.py CHANGED
@@ -3,180 +3,120 @@ import pandas as pd
3
  import joblib
4
  from sklearn.ensemble import RandomForestRegressor
5
  import plotly.express as px
 
6
 
7
  # Mapping for position to numeric values
8
  position_mapping = {
9
- "PG": 1.0, # Point Guard
10
- "SG": 2.0, # Shooting Guard
11
- "SF": 3.0, # Small Forward
12
- "PF": 4.0, # Power Forward
13
- "C": 5.0, # Center
14
  }
15
 
16
  # Predefined injury types
17
  injury_types = [
18
- "foot fracture injury",
19
- "hip flexor surgery injury",
20
- "calf strain injury",
21
- "quad injury injury",
22
- "shoulder sprain injury",
23
- "foot sprain injury",
24
- "torn rotator cuff injury injury",
25
- "torn mcl injury",
26
- "hip flexor strain injury",
27
- "fractured leg injury",
28
- "sprained mcl injury",
29
- "ankle sprain injury",
30
- "hamstring injury injury",
31
- "meniscus tear injury",
32
- "torn hamstring injury",
33
- "dislocated shoulder injury",
34
- "ankle fracture injury",
35
- "fractured hand injury",
36
- "bone spurs injury",
37
- "acl tear injury",
38
- "hip labrum injury",
39
- "back surgery injury",
40
- "arm injury injury",
41
- "torn shoulder labrum injury",
42
- "lower back spasm injury"
43
  ]
44
 
45
- # Injury average days dictionary
46
  average_days_injured = {
47
  "foot fracture injury": 207.666667,
48
- "hip flexor surgery injury": 256.000000,
49
- "calf strain injury": 236.000000,
50
- "quad injury injury": 283.000000,
51
- "shoulder sprain injury": 259.500000,
52
- "foot sprain injury": 294.000000,
53
- "torn rotator cuff injury injury": 251.500000,
54
- "torn mcl injury": 271.000000,
55
- "hip flexor strain injury": 253.000000,
56
- "fractured leg injury": 250.250000,
57
- "sprained mcl injury": 228.666667,
58
- "ankle sprain injury": 231.333333,
59
- "hamstring injury injury": 220.000000,
60
- "meniscus tear injury": 201.250000,
61
- "torn hamstring injury": 187.666667,
62
- "dislocated shoulder injury": 269.000000,
63
- "ankle fracture injury": 114.500000,
64
- "fractured hand injury": 169.142857,
65
- "bone spurs injury": 151.500000,
66
- "acl tear injury": 268.000000,
67
- "hip labrum injury": 247.500000,
68
- "back surgery injury": 215.800000,
69
- "arm injury injury": 303.666667,
70
- "torn shoulder labrum injury": 195.666667,
71
  "lower back spasm injury": 234.000000,
72
  }
73
 
74
-
75
-
76
-
77
-
78
- # Load player dataset
79
  @st.cache_resource
80
  def load_player_data():
81
  return pd.read_csv("player_data.csv")
82
 
83
- # Load Random Forest model
84
  @st.cache_resource
85
  def load_rf_model():
86
  return joblib.load("rf_injury_change_model.pkl")
87
 
88
  # Main Streamlit app
89
  def main():
90
- st.title("NBA Player Performance Predictor")
91
- st.write(
92
- """
93
- Predict how a player's performance metrics (e.g., points, rebounds, assists) might change
94
- if a hypothetical injury occurs, based on their position and other factors.
95
- """
96
- )
97
-
98
- # Load player data
99
  player_data = load_player_data()
100
  rf_model = load_rf_model()
101
 
102
- # Sidebar inputs
103
- st.sidebar.header("Player and Injury Input")
104
-
105
- # Dropdown for player selection
106
- player_list = sorted(player_data['player_name'].dropna().unique())
107
- player_name = st.sidebar.selectbox("Select Player", player_list)
108
 
109
- if player_name:
110
- # Retrieve player details
111
- player_row = player_data[player_data['player_name'] == player_name]
112
-
113
- if not player_row.empty:
114
  position = player_row.iloc[0]['position']
115
  position_numeric = position_mapping.get(position, 0)
116
-
117
- st.sidebar.write(f"**Position**: {position} (Numeric: {position_numeric})")
118
-
119
- # Default values for features
120
  stats_columns = ['age', 'player_height', 'player_weight']
121
- default_stats = {
122
- stat: player_row.iloc[0][stat] if stat in player_row.columns else 0
123
- for stat in stats_columns
124
- }
125
 
126
- # Allow manual adjustment of stats
127
  for stat in default_stats.keys():
128
- default_stats[stat] = st.sidebar.number_input(f"{stat}", value=default_stats[stat])
129
-
130
- # Injury details
131
- injury_type = st.sidebar.selectbox("Select Hypothetical Injury", injury_types)
132
- # Replace slider with default average based on injury type
133
- default_days_injured = average_days_injured[injury_type] or 30 # Use 30 if `None`
134
- days_injured = st.sidebar.slider(
135
- "Estimated Days Injured",
136
- 0,
137
- 365,
138
- int(default_days_injured),
139
- help=f"Default days for {injury_type}: {int(default_days_injured) if default_days_injured else 'N/A'}"
140
- )
141
- injury_occurrences = st.sidebar.number_input("Injury Occurrences", min_value=0, value=1)
142
-
143
- # Prepare input data
144
  input_data = pd.DataFrame([{
145
  "days_injured": days_injured,
146
  "injury_occurrences": injury_occurrences,
147
  "position": position_numeric,
148
- "injury_type": injury_type, # Include the selected injury type
149
  **default_stats
150
  }])
151
-
152
- # Encode injury type
153
  input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
154
 
155
- # Load Random Forest model
156
- try:
157
- rf_model = load_rf_model()
158
-
159
- # Align input data with the model's feature names
160
- expected_features = rf_model.feature_names_in_
161
- input_data = input_data.reindex(columns=rf_model.feature_names_in_, fill_value=0)
162
-
163
- # Predict and display results
164
- if st.sidebar.button("Predict"):
165
- predictions = rf_model.predict(input_data)
166
- predictions = [round(float(pred), 2) for pred in predictions]
167
- prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change inAST"]
168
- st.subheader("Predicted Post-Injury Performance")
169
- st.write("Based on the inputs, here are the predicted metrics:")
170
- st.table(pd.DataFrame(predictions, columns=prediction_columns))
171
- except FileNotFoundError:
172
- st.error("Model file not found.")
173
- except ValueError as e:
174
- st.error(f"Error during prediction: {e}")
175
-
176
- else:
177
- st.sidebar.error("Player details not found in the dataset.")
178
- else:
179
- st.sidebar.error("Please select a player to view details.")
 
 
 
 
 
 
 
 
180
 
181
  if __name__ == "__main__":
182
  main()
 
3
  import joblib
4
  from sklearn.ensemble import RandomForestRegressor
5
  import plotly.express as px
6
+ import plotly.graph_objects as go
7
 
8
  # Mapping for position to numeric values
9
  position_mapping = {
10
+ "PG": 1.0,
11
+ "SG": 2.0,
12
+ "SF": 3.0,
13
+ "PF": 4.0,
14
+ "C": 5.0,
15
  }
16
 
17
  # Predefined injury types
18
  injury_types = [
19
+ "foot fracture injury", "hip flexor surgery injury", "calf strain injury",
20
+ "quad injury injury", "shoulder sprain injury", "foot sprain injury",
21
+ "torn rotator cuff injury injury", "torn mcl injury", "hip flexor strain injury",
22
+ "fractured leg injury", "sprained mcl injury", "ankle sprain injury",
23
+ "hamstring injury injury", "meniscus tear injury", "torn hamstring injury",
24
+ "dislocated shoulder injury", "ankle fracture injury", "fractured hand injury",
25
+ "bone spurs injury", "acl tear injury", "hip labrum injury", "back surgery injury",
26
+ "arm injury injury", "torn shoulder labrum injury", "lower back spasm injury"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  ]
28
 
 
29
  average_days_injured = {
30
  "foot fracture injury": 207.666667,
31
+ # ... (other injuries with their averages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  "lower back spasm injury": 234.000000,
33
  }
34
 
 
 
 
 
 
35
  @st.cache_resource
36
  def load_player_data():
37
  return pd.read_csv("player_data.csv")
38
 
 
39
  @st.cache_resource
40
  def load_rf_model():
41
  return joblib.load("rf_injury_change_model.pkl")
42
 
43
  # Main Streamlit app
44
  def main():
45
+ st.title("NBA Player Performance Predictor 🏀")
46
+ st.write("""
47
+ Use this tool to predict how a player's performance metrics might change
48
+ if they experience a hypothetical injury.
49
+ """)
50
+
 
 
 
51
  player_data = load_player_data()
52
  rf_model = load_rf_model()
53
 
54
+ # Layout: Sidebar for Inputs
55
+ with st.sidebar:
56
+ st.header("Player & Injury Inputs")
57
+ player_list = sorted(player_data['player_name'].dropna().unique())
58
+ player_name = st.selectbox("Select Player", player_list)
 
59
 
60
+ # Player Details
61
+ if player_name:
62
+ player_row = player_data[player_data['player_name'] == player_name]
 
 
63
  position = player_row.iloc[0]['position']
64
  position_numeric = position_mapping.get(position, 0)
65
+ st.write(f"**Position**: {position} (Numeric: {position_numeric})")
66
+
 
 
67
  stats_columns = ['age', 'player_height', 'player_weight']
68
+ default_stats = {stat: player_row.iloc[0][stat] for stat in stats_columns}
 
 
 
69
 
 
70
  for stat in default_stats.keys():
71
+ default_stats[stat] = st.number_input(f"{stat}", value=default_stats[stat])
72
+
73
+ injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
74
+ default_days_injured = average_days_injured.get(injury_type, 30)
75
+ days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
76
+ injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)
77
+
 
 
 
 
 
 
 
 
 
78
  input_data = pd.DataFrame([{
79
  "days_injured": days_injured,
80
  "injury_occurrences": injury_occurrences,
81
  "position": position_numeric,
82
+ "injury_type": injury_type,
83
  **default_stats
84
  }])
 
 
85
  input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
86
 
87
+ # Main Layout
88
+ col1, col2 = st.columns(2)
89
+
90
+ with col1:
91
+ st.subheader("Player Details")
92
+ st.metric("Age", default_stats['age'])
93
+ st.metric("Height (cm)", default_stats['player_height'])
94
+ st.metric("Weight (kg)", default_stats['player_weight'])
95
+ st.image("player_placeholder.png", caption=f"{player_name}", width=200) # Placeholder image
96
+
97
+ with col2:
98
+ st.subheader("Prediction Results")
99
+
100
+ if st.button("Predict"):
101
+ predictions = rf_model.predict(input_data)
102
+ predictions = [round(float(pred), 2) for pred in predictions]
103
+
104
+ prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
105
+ result_df = pd.DataFrame([predictions], columns=prediction_columns)
106
+
107
+ # Bar Chart for Visualizing Predictions
108
+ fig = go.Figure(data=[
109
+ go.Bar(
110
+ x=prediction_columns,
111
+ y=predictions,
112
+ marker_color=["green" if val > 0 else "red" for val in predictions]
113
+ )
114
+ ])
115
+ fig.update_layout(title="Predicted Performance Changes", xaxis_title="Metrics", yaxis_title="Change")
116
+ st.plotly_chart(fig)
117
+
118
+ # Display predictions as a table
119
+ st.table(result_df)
120
 
121
  if __name__ == "__main__":
122
  main()