# Spaces: Sleeping (Hugging Face Spaces status text captured together with this source; not part of the program)
import streamlit as st | |
import pandas as pd | |
import joblib | |
from sklearn.ensemble import RandomForestRegressor | |
import plotly.express as px | |
from sklearn.ensemble import RandomForestRegressor | |
import plotly.graph_objects as go | |
from PIL import Image | |
import plotly.express as px | |
# Page-level settings for the browser tab (title, icon) and the content layout.
_PAGE_CONFIG = dict(
    layout="centered",
    page_title="NBA Player Performance Predictor",
    # NOTE(review): this icon literal looks like a mis-encoded emoji; confirm
    # the intended character before changing it.
    page_icon="๐",
)
st.set_page_config(**_PAGE_CONFIG)
# Global theme stylesheet (page background, buttons, headings, content
# container, tables and Plotly charts). It is emitted through st.markdown with
# unsafe_allow_html=True so the <style> element is rendered instead of escaped.
# NOTE(review): no markup written by this file uses the `sidebar`,
# `sidebar-content` or `styled-table` classes; whether Streamlit's own DOM still
# carries the first two is version-dependent — confirm these rules still apply.
_THEME_CSS = """
<style>
body {
background: linear-gradient(to bottom, #0033a0, #ed174c); /* NBA team colors gradient */
font-family: 'Trebuchet MS', sans-serif;
margin: 0;
padding: 0;
color: white; /* Set text color to white */
}
.sidebar .sidebar-content {
background: linear-gradient(to bottom, #4B0082, #1E90FF); /* Purple to blue gradient */
border-radius: 10px;
padding: 10px;
color: #ffffff; /* Set sidebar text color to white */
}
.sidebar h2 {
background: linear-gradient(to right, #FF1493, #FF4500); /* Pink to red gradient */
color: white; /* Text color */
padding: 10px;
text-align: center;
font-size: 20px;
font-weight: bold;
border-radius: 5px;
text-shadow: 2px 2px #000000; /* Add shadow for better visibility */
margin-bottom: 15px;
}
.stButton > button {
background-color: #ffcc00; /* Bold yellow */
color: #0033a0; /* Button text color */
border: none;
border-radius: 5px;
padding: 10px 15px;
font-size: 16px;
transition: background-color 0.3s ease;
}
.stButton > button:hover {
background-color: #ffc107; /* Brighter yellow */
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.2);
}
.stMarkdown h1, .stMarkdown h2, .stMarkdown h3 {
color: #ffffff; /* Set headings color to white */
text-shadow: 2px 2px #000000; /* Add shadow for better visibility */
}
.block-container {
border-radius: 10px;
padding: 20px;
background-color: rgba(0, 0, 0, 0.8); /* Dark semi-transparent background */
color: #ffffff; /* Ensure text inside the container is white */
}
.dataframe {
background-color: rgba(255, 255, 255, 0.1); /* Transparent table background */
color: #ffffff; /* Table text color */
border-radius: 10px;
}
.stPlotlyChart {
background-color: rgba(0, 0, 0, 0.8); /* Match dark theme */
padding: 10px;
border-radius: 10px;
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5);
}
.styled-table {
width: 100%;
border-collapse: collapse;
margin: 25px 0;
font-size: 18px;
text-align: left;
border-radius: 5px 5px 0 0;
overflow: hidden;
color: #ffffff; /* Table text color */
}
.styled-table th, .styled-table td {
padding: 12px 15px;
}
</style>
"""
st.markdown(_THEME_CSS, unsafe_allow_html=True)
def _standard_logo_filename(team):
    """Return the conventional ``nba-<team-slug>-logo.png`` filename for *team*."""
    return "nba-" + "-".join(team.lower().split()) + "-logo.png"


# Teams whose logo file does not follow the standard naming pattern.
# NOTE(review): "Clevelan-Cavaliers-logo-2022.png" looks misspelled, but it
# presumably matches the asset's real filename on disk — do not "fix" it
# without renaming the file too.
_LOGO_FILENAME_OVERRIDES = {
    "Cleveland Cavaliers": "Clevelan-Cavaliers-logo-2022.png",
    "Denver Nuggets": "nba-denver-nuggets-logo-2018.png",
    "Golden State Warriors": "nba-golden-state-warriors-logo-2020.png",
    "Houston Rockets": "nba-houston-rockets-logo-2020.png",
    "Toronto Raptors": "nba-toronto-raptors-logo-2020.png",
}

# Full team name -> logo image path (relative to the working directory).
team_logo_paths = {
    team: _LOGO_FILENAME_OVERRIDES.get(team, _standard_logo_filename(team))
    for team in (
        "Cleveland Cavaliers",
        "Atlanta Hawks",
        "Boston Celtics",
        "Brooklyn Nets",
        "Charlotte Hornets",
        "Chicago Bulls",
        "Dallas Mavericks",
        "Denver Nuggets",
        "Detroit Pistons",
        "Golden State Warriors",
        "Houston Rockets",
        "Indiana Pacers",
        "LA Clippers",
        "Los Angeles Lakers",
        "Memphis Grizzlies",
        "Miami Heat",
        "Milwaukee Bucks",
        "Minnesota Timberwolves",
        "New Orleans Pelicans",
        "New York Knicks",
        "Oklahoma City Thunder",
        "Orlando Magic",
        "Philadelphia 76ers",
        "Phoenix Suns",
        "Portland Trail Blazers",
        "Sacramento Kings",
        "San Antonio Spurs",
        "Toronto Raptors",
        "Utah Jazz",
        "Washington Wizards",
    )
}
# Court position abbreviation -> numeric code used as the "position" model
# feature, numbered 1.0 through 5.0 in traditional position order.
position_mapping = {
    abbreviation: float(code)
    for code, abbreviation in enumerate(
        (
            "PG",  # Point Guard
            "SG",  # Shooting Guard
            "SF",  # Small Forward
            "PF",  # Power Forward
            "C",  # Center
        ),
        start=1,
    )
}
# Hypothetical injuries offered in the sidebar. Every label ends with an extra
# " injury" suffix — even bases that already end in "injury", which yields
# labels such as "quad injury injury". The exact strings are kept because they
# are also the keys of the average-days lookup table and presumably mirror the
# labels the model was trained on.
injury_types = [
    f"{base} injury"
    for base in (
        "foot fracture",
        "hip flexor surgery",
        "calf strain",
        "quad injury",
        "shoulder sprain",
        "foot sprain",
        "torn rotator cuff injury",
        "torn mcl",
        "hip flexor strain",
        "fractured leg",
        "sprained mcl",
        "ankle sprain",
        "hamstring injury",
        "meniscus tear",
        "torn hamstring",
        "dislocated shoulder",
        "ankle fracture",
        "fractured hand",
        "bone spurs",
        "acl tear",
        "hip labrum",
        "back surgery",
        "arm injury",
        "torn shoulder labrum",
        "lower back spasm",
    )
]
# Average number of days missed per injury type, keyed by the same labels as
# the sidebar's injury list (base name + " injury"). main() uses it to seed the
# "Estimated Days Injured" slider, falling back to 30 for unknown labels.
average_days_injured = {
    f"{base} injury": days
    for base, days in (
        ("foot fracture", 207.666667),
        ("hip flexor surgery", 256.0),
        ("calf strain", 236.0),
        ("quad injury", 283.0),
        ("shoulder sprain", 259.5),
        ("foot sprain", 294.0),
        ("torn rotator cuff injury", 251.5),
        ("torn mcl", 271.0),
        ("hip flexor strain", 253.0),
        ("fractured leg", 250.25),
        ("sprained mcl", 228.666667),
        ("ankle sprain", 231.333333),
        ("hamstring injury", 220.0),
        ("meniscus tear", 201.25),
        ("torn hamstring", 187.666667),
        ("dislocated shoulder", 269.0),
        ("ankle fracture", 114.5),
        ("fractured hand", 169.142857),
        ("bone spurs", 151.5),
        ("acl tear", 268.0),
        ("hip labrum", 247.5),
        ("back surgery", 215.8),
        ("arm injury", 303.666667),
        ("torn shoulder labrum", 195.666667),
        ("lower back spasm", 234.0),
    )
}
# Load player dataset
@st.cache_data
def load_player_data(path="player_data.csv"):
    """Read the player table from the CSV file at ``path``.

    Streamlit re-executes the whole script on every widget interaction, so the
    function is wrapped in ``st.cache_data``: the CSV is parsed once per
    distinct ``path`` and later calls get a copy of the cached DataFrame.
    Consequence: edits to the file on disk are not picked up until the cache
    is cleared or the app restarts.

    Args:
        path: CSV file to read; defaults to ``"player_data.csv"`` relative to
            the working directory.

    Returns:
        pandas.DataFrame: the parsed player table.
    """
    return pd.read_csv(path)
# Load Random Forest model
@st.cache_resource
def load_rf_model(path="rf_injury_change_model.pkl"):
    """Deserialize the trained model stored at ``path`` with ``joblib.load``.

    Wrapped in ``st.cache_resource`` so the pickle is loaded once and the same
    model object is reused across reruns instead of being unpickled on every
    widget interaction. Because the object is shared, callers should treat it
    as read-only.

    Security: ``joblib.load`` unpickles the file, which can execute arbitrary
    code — only point ``path`` at trusted model files.

    Args:
        path: model file to load; defaults to ``"rf_injury_change_model.pkl"``
            relative to the working directory.

    Returns:
        The deserialized model object (used by main() via ``.predict``).
    """
    return joblib.load(path)
# Main Streamlit app

# NBA three-letter team code -> full team name as used by ``team_logo_paths``.
# The player table exposes the team through its ``team_abbreviation`` column,
# whereas ``team_logo_paths`` is keyed by full names, so a raw code would never
# match a logo. Values not listed here (e.g. data that already holds full
# names) are passed through unchanged.
# NOTE(review): assumes the CSV uses the current NBA codes; add historical
# codes here if the data contains them.
_TEAM_NAME_BY_ABBREVIATION = {
    "ATL": "Atlanta Hawks",
    "BOS": "Boston Celtics",
    "BKN": "Brooklyn Nets",
    "CHA": "Charlotte Hornets",
    "CHI": "Chicago Bulls",
    "CLE": "Cleveland Cavaliers",
    "DAL": "Dallas Mavericks",
    "DEN": "Denver Nuggets",
    "DET": "Detroit Pistons",
    "GSW": "Golden State Warriors",
    "HOU": "Houston Rockets",
    "IND": "Indiana Pacers",
    "LAC": "LA Clippers",
    "LAL": "Los Angeles Lakers",
    "MEM": "Memphis Grizzlies",
    "MIA": "Miami Heat",
    "MIL": "Milwaukee Bucks",
    "MIN": "Minnesota Timberwolves",
    "NOP": "New Orleans Pelicans",
    "NYK": "New York Knicks",
    "OKC": "Oklahoma City Thunder",
    "ORL": "Orlando Magic",
    "PHI": "Philadelphia 76ers",
    "PHX": "Phoenix Suns",
    "POR": "Portland Trail Blazers",
    "SAC": "Sacramento Kings",
    "SAS": "San Antonio Spurs",
    "TOR": "Toronto Raptors",
    "UTA": "Utah Jazz",
    "WAS": "Washington Wizards",
}


def _encode_injury_type(injury_type):
    """Return the integer code the model receives for ``injury_type``.

    Factorizing the one-row input frame would always yield code 0, so the
    selected injury would never reach the model. The label's position in
    ``injury_types`` is used instead, giving every injury a distinct code.

    NOTE(review): this must match the encoding used when the model was trained
    (e.g. factorize order over the training data) — confirm before trusting
    the predictions.

    Raises:
        ValueError: if ``injury_type`` is not one of ``injury_types``.
    """
    return injury_types.index(injury_type)


def _render_sidebar_header():
    """Draw the gradient "Player Details" banner at the top of the sidebar."""
    st.sidebar.markdown(
        """
<div style="padding: 10px; background: linear-gradient(to right, #6a11cb, #2575fc); color: white; border-radius: 10px;">
<h3>Player Details</h3>
</div>
""",
        unsafe_allow_html=True,
    )


def _read_sidebar_inputs(player_row):
    """Collect the sidebar inputs for the selected player and build model input.

    Args:
        player_row: rows of the player table for the selected player; the first
            row supplies the default stats and the position.

    Returns:
        tuple: ``(stats, input_data)`` — ``stats`` maps ``age``,
        ``player_height`` and ``player_weight`` to the (possibly edited)
        sidebar values; ``input_data`` is a one-row DataFrame of model features.
    """
    first = player_row.iloc[0]
    # Widgets are created in the same order as the keys, so the sidebar layout
    # stays age -> height -> weight.
    stats = {
        stat: st.sidebar.number_input(f"{stat}", value=first[stat])
        for stat in ("age", "player_height", "player_weight")
    }
    injury_type = st.sidebar.selectbox("Select Hypothetical Injury", injury_types)
    default_days_injured = average_days_injured.get(injury_type, 30)
    days_injured = st.sidebar.slider(
        "Estimated Days Injured", 0, 365, int(default_days_injured)
    )
    injury_occurrences = st.sidebar.number_input(
        "Injury Occurrences", min_value=0, value=1
    )
    # Column order is kept exactly as before, in case the model checks feature
    # names/order.
    input_data = pd.DataFrame([{
        "days_injured": days_injured,
        "injury_occurrences": injury_occurrences,
        "position": position_mapping.get(first["position"], 0),
        "injury_type": _encode_injury_type(injury_type),
        **stats,
    }])
    return stats, input_data


def _render_player_overview(stats, team):
    """Show the player's stats and, when a logo file is known, the team logo.

    Args:
        stats: mapping with ``age``, ``player_height`` and ``player_weight``.
        team: value of the player's ``team_abbreviation`` column.
    """
    st.divider()
    st.header("Player Overview")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Player Details")
        st.metric("Age", stats["age"])
        st.metric("Height (cm)", stats["player_height"])
        st.metric("Weight (kg)", stats["player_weight"])
    with col2:
        team_name = _TEAM_NAME_BY_ABBREVIATION.get(team, team)
        logo_path = team_logo_paths.get(team_name)
        if logo_path is not None:
            try:
                # Context manager so the image file handle is closed promptly.
                with Image.open(logo_path) as logo_image:
                    # NOTE(review): newer Streamlit versions deprecate
                    # use_column_width in favour of use_container_width.
                    st.image(
                        logo_image,
                        caption=f"{team_name} Logo",
                        use_column_width=True,
                    )
            except FileNotFoundError:
                st.error(f"Logo for {team_name} not found.")


def _render_predictions(rf_model, input_data):
    """Predict the stat changes for ``input_data`` and show a table and bar chart.

    Assumes ``rf_model.predict`` returns one row with three values (PTS, REB,
    AST changes) — TODO confirm against the training code.
    """
    predictions = rf_model.predict(input_data)
    prediction_columns = [
        "Predicted Change in PTS",
        "Predicted Change in REB",
        "Predicted Change in AST",
    ]
    prediction_df = pd.DataFrame(predictions, columns=prediction_columns)
    st.subheader("Predicted Post-Injury Performance")
    st.write(prediction_df)

    palette = px.colors.qualitative.Plotly
    fig = go.Figure()
    for idx, col in enumerate(prediction_columns):
        fig.add_trace(go.Bar(
            x=[col],
            y=prediction_df[col],
            name=col,
            marker=dict(color=palette[idx]),
        ))
    fig.update_layout(
        title="Predicted Performance Changes",
        xaxis_title="Metrics",
        yaxis_title="Change Value",
        template="plotly_dark",
    )
    st.plotly_chart(fig, use_container_width=True)


def main():
    """Render the NBA Player Performance Predictor page.

    Loads the player table and model, collects the player and
    hypothetical-injury inputs in the sidebar, shows the player overview and,
    when the "Predict" button is pressed, the predicted stat changes.
    """
    st.title("NBA Player Performance Predictor ๐")
    st.write(
        """
Welcome to the **NBA Player Performance Predictor**! This app helps predict changes in a player's performance metrics
after experiencing a hypothetical injury. Simply input the details and see the magic happen!
"""
    )
    # Load player data and model
    player_data = load_player_data()
    rf_model = load_rf_model()

    _render_sidebar_header()
    player_list = sorted(player_data["player_name"].dropna().unique())
    player_name = st.sidebar.selectbox("Select Player", player_list)
    if not player_name:
        # Nothing selectable (e.g. empty player list): stop after the intro.
        return

    player_row = player_data[player_data["player_name"] == player_name]
    stats, input_data = _read_sidebar_inputs(player_row)
    _render_player_overview(stats, player_row.iloc[0]["team_abbreviation"])

    if st.sidebar.button("Predict ๐ฎ"):
        _render_predictions(rf_model, input_data)
# Script entry point: build the app when this file is run as the main script.
if __name__ == "__main__":
    main()