Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import joblib | |
import plotly.express as px | |
from PIL import Image | |
# Browser-tab title/icon and a centered layout for the whole page.
_PAGE_SETTINGS = {
    "page_title": "NBA Player Performance Predictor π",
    "page_icon": "π",
    "layout": "centered",
}
st.set_page_config(**_PAGE_SETTINGS)
# Numeric code for each roster position abbreviation, passed to the model as
# the "position" feature: PG=1.0 through C=5.0, in this order.
position_mapping = {
    abbreviation: float(code)
    for code, abbreviation in enumerate(("PG", "SG", "SF", "PF", "C"), start=1)
}
# Average days injured per injury type. main() uses these as the default for the
# "Estimated Days Injured" slider. Labels are kept verbatim (including the
# doubled "injury injury" suffixes) — presumably they must match the
# injury_type categories the model was trained on; verify before renaming.
average_days_injured = {
    "foot fracture injury": 207.666667,
    "hip flexor surgery injury": 256.000000,
    "calf strain injury": 236.000000,
    "quad injury injury": 283.000000,
    "shoulder sprain injury": 259.500000,
    "foot sprain injury": 294.000000,
    "torn rotator cuff injury injury": 251.500000,
    "torn mcl injury": 271.000000,
    "hip flexor strain injury": 253.000000,
    "fractured leg injury": 250.250000,
    "sprained mcl injury": 228.666667,
    "ankle sprain injury": 231.333333,
    "hamstring injury injury": 220.000000,
    "meniscus tear injury": 201.250000,
    "torn hamstring injury": 187.666667,
    "dislocated shoulder injury": 269.000000,
    "ankle fracture injury": 114.500000,
    "fractured hand injury": 169.142857,
    "bone spurs injury": 151.500000,
    "acl tear injury": 268.000000,
    "hip labrum injury": 247.500000,
    "back surgery injury": 215.800000,
    "arm injury injury": 303.666667,
    "torn shoulder labrum injury": 195.666667,
    "lower back spasm injury": 234.000000,
}

# Injury types offered in the dropdown. Derived from the table above so the
# two can never drift apart; dicts keep insertion order, so the dropdown order
# is the same as the table's.
injury_types = list(average_days_injured)
# (team full name, logo image path) pairs; the dict built below keeps this order.
# NOTE(review): nothing in the visible code reads this map. main() shows the
# "team_abbreviation" column, which presumably holds abbreviations and would not
# match these full-name keys — confirm before wiring logos in.
_TEAM_LOGO_ENTRIES = (
    ("Cleveland Cavaliers", "NBA_LOGOs/Clevelan-Cavaliers-logo-2022.png"),
    ("Atlanta Hawks", "NBA_LOGOs/nba-atlanta-hawks-logo.png"),
    ("Boston Celtics", "NBA_LOGOs/nba-boston-celtics-logo.png"),
    ("Brooklyn Nets", "NBA_LOGOs/nba-brooklyn-nets-logo.png"),
    ("Charlotte Hornets", "NBA_LOGOs/nba-charlotte-hornets-logo.png"),
    ("Chicago Bulls", "NBA_LOGOs/nba-chicago-bulls-logo.png"),
    ("Dallas Mavericks", "NBA_LOGOs/nba-dallas-mavericks-logo.png"),
    ("Denver Nuggets", "NBA_LOGOs/nba-denver-nuggets-logo-2018.png"),
    ("Detroit Pistons", "NBA_LOGOs/nba-detroit-pistons-logo.png"),
    ("Golden State Warriors", "NBA_LOGOs/nba-golden-state-warriors-logo-2020.png"),
    ("Houston Rockets", "NBA_LOGOs/nba-houston-rockets-logo-2020.png"),
    ("Indiana Pacers", "NBA_LOGOs/nba-indiana-pacers-logo.png"),
    ("LA Clippers", "NBA_LOGOs/nba-la-clippers-logo.png"),
    ("Los Angeles Lakers", "NBA_LOGOs/nba-los-angeles-lakers-logo.png"),
    ("Memphis Grizzlies", "NBA_LOGOs/nba-memphis-grizzlies-logo.png"),
    ("Miami Heat", "NBA_LOGOs/nba-miami-heat-logo.png"),
    ("Milwaukee Bucks", "NBA_LOGOs/nba-milwaukee-bucks-logo.png"),
    ("Minnesota Timberwolves", "NBA_LOGOs/nba-minnesota-timberwolves-logo.png"),
    ("New Orleans Pelicans", "NBA_LOGOs/nba-new-orleans-pelicans-logo.png"),
    ("New York Knicks", "NBA_LOGOs/nba-new-york-knicks-logo.png"),
    ("Oklahoma City Thunder", "NBA_LOGOs/nba-oklahoma-city-thunder-logo.png"),
    ("Orlando Magic", "NBA_LOGOs/nba-orlando-magic-logo.png"),
    ("Philadelphia 76ers", "NBA_LOGOs/nba-philadelphia-76ers-logo.png"),
    ("Phoenix Suns", "NBA_LOGOs/nba-phoenix-suns-logo.png"),
    ("Portland Trail Blazers", "NBA_LOGOs/nba-portland-trail-blazers-logo.png"),
    ("Sacramento Kings", "NBA_LOGOs/nba-sacramento-kings-logo.png"),
    ("San Antonio Spurs", "NBA_LOGOs/nba-san-antonio-spurs-logo.png"),
    ("Toronto Raptors", "NBA_LOGOs/nba-toronto-raptors-logo-2020.png"),
    ("Utah Jazz", "NBA_LOGOs/nba-utah-jazz-logo.png"),
    ("Washington Wizards", "NBA_LOGOs/nba-washington-wizards-logo.png"),
)
team_logo_paths = dict(_TEAM_LOGO_ENTRIES)
# Cached loader: Streamlit re-runs the whole script on every widget
# interaction, so without a cache the CSV would be re-read on each rerun.
@st.cache_data
def load_player_data(path="player_data.csv"):
    """Read the player table from the CSV file at ``path`` and return it.

    Cached with ``st.cache_data``, so the file is parsed once per distinct
    ``path``. Each call gets its own copy of the cached DataFrame, so callers
    can modify the result safely.
    """
    return pd.read_csv(path)
@st.cache_resource
def load_rf_model(path="rf_injury_change_model.pkl"):
    """Load the serialized model at ``path`` with joblib and return it.

    Cached with ``st.cache_resource``, so one model object is shared across
    reruns instead of being reloaded on every widget interaction. Treat the
    returned object as read-only. A missing file raises FileNotFoundError,
    and failures are not cached.
    """
    # joblib.load unpickles the file, which can execute arbitrary code — only
    # point this at a trusted, locally shipped model file.
    return joblib.load(path)
# Main application
def _build_feature_frame(
    days_injured,
    injury_occurrences,
    position_code,
    injury_type,
    stats,
    expected_features,
):
    """Return a one-row feature DataFrame whose columns are exactly ``expected_features``.

    ``injury_type`` is one-hot encoded into ``injury_type_<label>`` columns
    (pandas.get_dummies naming). Expected columns missing from the row are
    filled with 0, and columns the model does not expect are dropped.
    """
    row = pd.DataFrame([{
        "days_injured": days_injured,
        "injury_occurrences": injury_occurrences,
        "position": position_code,
        "injury_type": injury_type,
        **stats,
    }])
    # drop_first must be False here. A single row holds only one category, so
    # drop_first=True would delete the selected injury's indicator and every
    # injury type would encode as all zeros. If the model was trained with
    # drop_first=True (presumably — verify), the baseline category's column is
    # not in expected_features and reindex drops it, which is the correct
    # all-zero baseline encoding.
    row = pd.get_dummies(row, columns=["injury_type"], drop_first=False)
    return row.reindex(columns=list(expected_features), fill_value=0)


def main():
    """Render the predictor page.

    Loads the player table and the model, collects player and injury inputs in
    the sidebar, and shows the model's predicted PTS/REB/AST changes when the
    sidebar "Predict" button is clicked.
    """
    st.title("NBA Player Performance Predictor π")
    st.markdown(
        """
        Use this tool to predict how a player's performance metrics might change
        if they experience a hypothetical injury.
        """
    )

    # Load the model here rather than inside the prediction try-block, so a
    # missing model file produces the error message instead of an uncaught
    # exception.
    player_data = load_player_data()
    try:
        rf_model = load_rf_model()
    except FileNotFoundError:
        st.error("Model file not found.")
        return

    # Sidebar inputs
    with st.sidebar:
        st.header("Player & Injury Inputs")
        player_list = sorted(player_data["player_name"].dropna().unique())
        player_name = st.selectbox("Select Player", player_list)
        if not player_name:
            # selectbox returns None for an empty list. Everything below needs a
            # selected player, so stop here instead of raising NameError later.
            st.warning("No players available to select.")
            return

        player_row = player_data[player_data["player_name"] == player_name].iloc[0]
        team_name = player_row["team_abbreviation"]
        position = player_row["position"]
        st.write(f"**Position**: {position}")
        st.write(f"**Team**: {team_name}")

        # Editable physical stats, pre-filled from the player's row.
        stats_columns = ["age", "player_height", "player_weight"]
        default_stats = {}
        for stat in stats_columns:
            default_stats[stat] = st.number_input(f"{stat}", value=player_row[stat])

        # Injury details
        injury_type = st.selectbox("Select Hypothetical Injury", injury_types)
        default_days_injured = average_days_injured.get(injury_type, 30)
        days_injured = st.slider("Estimated Days Injured", 0, 365, int(default_days_injured))
        injury_occurrences = st.number_input("Injury Occurrences", min_value=0, value=1)

    input_data = _build_feature_frame(
        days_injured=days_injured,
        injury_occurrences=injury_occurrences,
        position_code=position_mapping.get(position, 0),
        injury_type=injury_type,
        stats=default_stats,
        expected_features=rf_model.feature_names_in_,
    )

    # Predict and display results
    st.header("Prediction Results")
    if st.sidebar.button("Predict"):
        prediction_columns = [
            "Predicted Change in PTS",
            "Predicted Change in REB",
            "Predicted Change in AST",
        ]
        try:
            predictions = rf_model.predict(input_data)
            # Also raises ValueError if the model does not return 3 outputs per row.
            results = pd.DataFrame(predictions, columns=prediction_columns)
        except ValueError as e:
            st.error(f"Error during prediction: {e}")
        else:
            st.subheader("Predicted Post-Injury Performance")
            st.write("Based on the inputs, here are the predicted metrics:")
            st.table(results)


if __name__ == "__main__":
    main()