lrschuman17 commited on
Commit
373112b
·
verified ·
1 Parent(s): 05c4948

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import joblib
4
+ from sklearn.ensemble import RandomForestRegressor
5
+ import plotly.express as px
6
+
7
+ # Mapping for position to numeric values
8
+ position_mapping = {
9
+ "PG": 1.0, # Point Guard
10
+ "SG": 2.0, # Shooting Guard
11
+ "SF": 3.0, # Small Forward
12
+ "PF": 4.0, # Power Forward
13
+ "C": 5.0, # Center
14
+ }
15
+
16
+ # Predefined injury types
17
+ injury_types = [
18
+ "foot fracture injury",
19
+ "hip flexor surgery injury",
20
+ "calf strain injury",
21
+ "quad injury injury",
22
+ "shoulder sprain injury",
23
+ "foot sprain injury",
24
+ "torn rotator cuff injury injury",
25
+ "torn mcl injury",
26
+ "hip flexor strain injury",
27
+ "fractured leg injury",
28
+ "sprained mcl injury",
29
+ "ankle sprain injury",
30
+ "hamstring injury injury",
31
+ "meniscus tear injury",
32
+ "torn hamstring injury",
33
+ "dislocated shoulder injury",
34
+ "ankle fracture injury",
35
+ "fractured hand injury",
36
+ "bone spurs injury",
37
+ "acl tear injury",
38
+ "hip labrum injury",
39
+ "back surgery injury",
40
+ "arm injury injury",
41
+ "torn shoulder labrum injury",
42
+ "lower back spasm injury"
43
+ ]
44
+
45
+ # Injury average days dictionary
46
+ average_days_injured = {
47
+ "foot fracture injury": 207.666667,
48
+ "hip flexor surgery injury": 256.000000,
49
+ "calf strain injury": 236.000000,
50
+ "quad injury injury": 283.000000,
51
+ "shoulder sprain injury": 259.500000,
52
+ "foot sprain injury": 294.000000,
53
+ "torn rotator cuff injury injury": 251.500000,
54
+ "torn mcl injury": 271.000000,
55
+ "hip flexor strain injury": 253.000000,
56
+ "fractured leg injury": 250.250000,
57
+ "sprained mcl injury": 228.666667,
58
+ "ankle sprain injury": 231.333333,
59
+ "hamstring injury injury": 220.000000,
60
+ "meniscus tear injury": 201.250000,
61
+ "torn hamstring injury": 187.666667,
62
+ "dislocated shoulder injury": 269.000000,
63
+ "ankle fracture injury": 114.500000,
64
+ "fractured hand injury": 169.142857,
65
+ "bone spurs injury": 151.500000,
66
+ "acl tear injury": 268.000000,
67
+ "hip labrum injury": 247.500000,
68
+ "back surgery injury": 215.800000,
69
+ "arm injury injury": 303.666667,
70
+ "torn shoulder labrum injury": 195.666667,
71
+ "lower back spasm injury": 234.000000,
72
+ }
73
+
74
+
75
+
76
+
77
+
78
+ # Load player dataset
79
+ @st.cache_resource
80
+ def load_player_data():
81
+ return pd.read_csv("/Users/laraschuman/Desktop/CTP-Project/player_data.csv")
82
+
83
+ # Load Random Forest model
84
+ @st.cache_resource
85
+ def load_rf_model():
86
+ return joblib.load("/Users/laraschuman/Desktop/CTP-Project/rf_injury_change_model.pkl")
87
+
88
+ # Main Streamlit app
89
+ def main():
90
+ st.title("NBA Player Performance Predictor")
91
+ st.write(
92
+ """
93
+ Predict how a player's performance metrics (e.g., points, rebounds, assists) might change
94
+ if a hypothetical injury occurs, based on their position and other factors.
95
+ """
96
+ )
97
+
98
+ # Load player data
99
+ player_data = load_player_data()
100
+ rf_model = load_rf_model()
101
+
102
+ # Sidebar inputs
103
+ st.sidebar.header("Player and Injury Input")
104
+
105
+ # Dropdown for player selection
106
+ player_list = sorted(player_data['player_name'].dropna().unique())
107
+ player_name = st.sidebar.selectbox("Select Player", player_list)
108
+
109
+ if player_name:
110
+ # Retrieve player details
111
+ player_row = player_data[player_data['player_name'] == player_name]
112
+
113
+ if not player_row.empty:
114
+ position = player_row.iloc[0]['position']
115
+ position_numeric = position_mapping.get(position, 0)
116
+
117
+ st.sidebar.write(f"**Position**: {position} (Numeric: {position_numeric})")
118
+
119
+ # Default values for features
120
+ stats_columns = ['age', 'player_height', 'player_weight']
121
+ default_stats = {
122
+ stat: player_row.iloc[0][stat] if stat in player_row.columns else 0
123
+ for stat in stats_columns
124
+ }
125
+
126
+ # Allow manual adjustment of stats
127
+ for stat in default_stats.keys():
128
+ default_stats[stat] = st.sidebar.number_input(f"{stat}", value=default_stats[stat])
129
+
130
+ # Injury details
131
+ injury_type = st.sidebar.selectbox("Select Hypothetical Injury", injury_types)
132
+ # Replace slider with default average based on injury type
133
+ default_days_injured = average_days_injured[injury_type] or 30 # Use 30 if `None`
134
+ days_injured = st.sidebar.slider(
135
+ "Estimated Days Injured",
136
+ 0,
137
+ 365,
138
+ int(default_days_injured),
139
+ help=f"Default days for {injury_type}: {int(default_days_injured) if default_days_injured else 'N/A'}"
140
+ )
141
+ injury_occurrences = st.sidebar.number_input("Injury Occurrences", min_value=0, value=1)
142
+
143
+ # Prepare input data
144
+ input_data = pd.DataFrame([{
145
+ "days_injured": days_injured,
146
+ "injury_occurrences": injury_occurrences,
147
+ "position": position_numeric,
148
+ "injury_type": injury_type, # Include the selected injury type
149
+ **default_stats
150
+ }])
151
+
152
+ # Encode injury type
153
+ input_data["injury_type"] = pd.factorize(input_data["injury_type"])[0]
154
+
155
+ # Load Random Forest model
156
+ try:
157
+ rf_model = load_rf_model()
158
+
159
+ # Align input data with the model's feature names
160
+ expected_features = rf_model.feature_names_in_
161
+ input_data = input_data.reindex(columns=rf_model.feature_names_in_, fill_value=0)
162
+
163
+ # Predict and display results
164
+ if st.sidebar.button("Predict"):
165
+ predictions = rf_model.predict(input_data)
166
+ prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change inAST"]
167
+ st.subheader("Predicted Post-Injury Performance")
168
+ st.write("Based on the inputs, here are the predicted metrics:")
169
+ st.table(pd.DataFrame(predictions, columns=prediction_columns))
170
+ except FileNotFoundError:
171
+ st.error("Model file not found.")
172
+ except ValueError as e:
173
+ st.error(f"Error during prediction: {e}")
174
+
175
+ else:
176
+ st.sidebar.error("Player details not found in the dataset.")
177
+ else:
178
+ st.sidebar.error("Please select a player to view details.")
179
+
180
+ if __name__ == "__main__":
181
+ main()