# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import streamlit as st | |
import pandas as pd | |
import joblib | |
from sklearn.ensemble import RandomForestRegressor | |
import plotly.express as px | |
from sklearn.ensemble import RandomForestRegressor | |
import plotly.graph_objects as go | |
from PIL import Image | |
import plotly.express as px | |
# Set the page configuration: browser-tab title and icon, and a centered layout.
# (The icon was previously stored as mojibake "๐" — the UTF-8 bytes of the
# basketball emoji decoded as a Thai code page.)
st.set_page_config(
    page_title="NBA Player Performance Predictor",
    page_icon="🏀",
    layout="centered",
)
# Custom CSS for vibrant NBA styling, injected once as a raw <style> element.
# Class names such as .styled-table are referenced by the HTML that main()
# renders; the Streamlit-internal selectors (.stButton, .block-container, ...)
# depend on Streamlit's generated DOM and may need updating across versions.
st.markdown(
    """
    <style>
    body {
        background: linear-gradient(to bottom, #0033a0, #ed174c); /* NBA team colors gradient */
        font-family: 'Trebuchet MS', sans-serif;
        margin: 0;
        padding: 0;
        color: white; /* Set text color to white */
    }
    .sidebar .sidebar-content {
        background: linear-gradient(to bottom, #4B0082, #1E90FF); /* Purple to blue gradient */
        border-radius: 10px;
        padding: 10px;
        color: #ffffff; /* Set sidebar text color to white */
    }
    .stSidebar h2 {
        color: #ffffff;
        text-align: center;
        font-size: 20px;
        font-weight: bold;
        text-shadow: 2px 2px #000000;
    }
    .stButton > button {
        background-color: #ffcc00; /* Bold yellow */
        color: #0033a0; /* Button text color */
        border: none;
        border-radius: 5px;
        padding: 10px 15px;
        font-size: 16px;
        transition: background-color 0.3s ease;
    }
    .stButton > button:hover {
        background-color: #ffc107; /* Brighter yellow */
        box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.2);
    }
    .stMarkdown h1, .stMarkdown h2, .stMarkdown h3 {
        color: #ffffff; /* Set headings color to white */
        text-shadow: 2px 2px #000000; /* Add shadow for better visibility */
    }
    .stMetric {
        color: #FFFFFF !important; /* Make metric text white */
        border: none; /* Remove any default borders */
        padding: 5px; /* Add padding for better spacing */
        font-size: 1.2em; /* Slightly increase font size */
        text-align: center; /* Center-align the metric text */
    }
    .block-container {
        border-radius: 10px;
        padding: 20px;
        background-color: rgba(0, 0, 0, 0.8); /* Dark semi-transparent background */
        color: #ffffff; /* Ensure text inside the container is white */
    }
    .dataframe {
        background-color: rgba(255, 255, 255, 0.1); /* Transparent table background */
        color: #ffffff; /* Table text color */
        border-radius: 10px;
    }
    .stPlotlyChart {
        background-color: rgba(0, 0, 0, 0.8); /* Match dark theme */
        padding: 10px;
        border-radius: 10px;
        box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5);
    }
    .styled-table {
        width: 100%;
        border-collapse: collapse;
        margin: 25px 0;
        font-size: 18px;
        text-align: left;
        border-radius: 5px 5px 0 0;
        overflow: hidden;
        color: #ffffff; /* Table text color */
    }
    .styled-table th, .styled-table td {
        padding: 12px 15px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# Full team name -> logo image file, resolved relative to the working directory.
# NOTE(review): main() looks teams up by the dataset's `team_abbreviation`
# column; if that column holds abbreviations (e.g. "LAL") rather than these full
# names, no logo will ever match — confirm against player_data.csv.
# The "Clevelan-..." file name is kept verbatim; it presumably matches the file
# on disk.
team_logo_paths = {
    "Cleveland Cavaliers": "Clevelan-Cavaliers-logo-2022.png",
    "Atlanta Hawks": "nba-atlanta-hawks-logo.png",
    "Boston Celtics": "nba-boston-celtics-logo.png",
    "Brooklyn Nets": "nba-brooklyn-nets-logo.png",
    "Charlotte Hornets": "nba-charlotte-hornets-logo.png",
    "Chicago Bulls": "nba-chicago-bulls-logo.png",
    "Dallas Mavericks": "nba-dallas-mavericks-logo.png",
    "Denver Nuggets": "nba-denver-nuggets-logo-2018.png",
    "Detroit Pistons": "nba-detroit-pistons-logo.png",
    "Golden State Warriors": "nba-golden-state-warriors-logo-2020.png",
    "Houston Rockets": "nba-houston-rockets-logo-2020.png",
    "Indiana Pacers": "nba-indiana-pacers-logo.png",
    "LA Clippers": "nba-la-clippers-logo.png",
    "Los Angeles Lakers": "nba-los-angeles-lakers-logo.png",
    "Memphis Grizzlies": "nba-memphis-grizzlies-logo.png",
    "Miami Heat": "nba-miami-heat-logo.png",
    "Milwaukee Bucks": "nba-milwaukee-bucks-logo.png",
    "Minnesota Timberwolves": "nba-minnesota-timberwolves-logo.png",
    "New Orleans Pelicans": "nba-new-orleans-pelicans-logo.png",
    "New York Knicks": "nba-new-york-knicks-logo.png",
    "Oklahoma City Thunder": "nba-oklahoma-city-thunder-logo.png",
    "Orlando Magic": "nba-orlando-magic-logo.png",
    "Philadelphia 76ers": "nba-philadelphia-76ers-logo.png",
    "Phoenix Suns": "nba-phoenix-suns-logo.png",
    "Portland Trail Blazers": "nba-portland-trail-blazers-logo.png",
    "Sacramento Kings": "nba-sacramento-kings-logo.png",
    "San Antonio Spurs": "nba-san-antonio-spurs-logo.png",
    "Toronto Raptors": "nba-toronto-raptors-logo-2020.png",
    "Utah Jazz": "nba-utah-jazz-logo.png",
    "Washington Wizards": "nba-washington-wizards-logo.png",
}
# Mapping from position abbreviation to the numeric value fed to the model as
# the "position" feature. main() falls back to 0 for positions not listed here.
position_mapping = {
    "PG": 1.0,  # Point Guard
    "SG": 2.0,  # Shooting Guard
    "SF": 3.0,  # Small Forward
    "PF": 4.0,  # Power Forward
    "C": 5.0,  # Center
}
# Predefined injury types offered in the "Select Hypothetical Injury" dropdown.
# Kept verbatim (including doubled words such as "quad injury injury"):
# presumably these are the exact labels from the training data, so they must
# not be "corrected" — confirm before editing. Each entry is also a key of
# average_days_injured.
injury_types = [
    "foot fracture injury",
    "hip flexor surgery injury",
    "calf strain injury",
    "quad injury injury",
    "shoulder sprain injury",
    "foot sprain injury",
    "torn rotator cuff injury injury",
    "torn mcl injury",
    "hip flexor strain injury",
    "fractured leg injury",
    "sprained mcl injury",
    "ankle sprain injury",
    "hamstring injury injury",
    "meniscus tear injury",
    "torn hamstring injury",
    "dislocated shoulder injury",
    "ankle fracture injury",
    "fractured hand injury",
    "bone spurs injury",
    "acl tear injury",
    "hip labrum injury",
    "back surgery injury",
    "arm injury injury",
    "torn shoulder labrum injury",
    "lower back spasm injury",
]
# Injury type -> average number of days injured, used as the default for the
# "Estimated Days Injured" slider. Keys mirror the injury_types list.
average_days_injured = {
    "foot fracture injury": 207.666667,
    "hip flexor surgery injury": 256.000000,
    "calf strain injury": 236.000000,
    "quad injury injury": 283.000000,
    "shoulder sprain injury": 259.500000,
    "foot sprain injury": 294.000000,
    "torn rotator cuff injury injury": 251.500000,
    "torn mcl injury": 271.000000,
    "hip flexor strain injury": 253.000000,
    "fractured leg injury": 250.250000,
    "sprained mcl injury": 228.666667,
    "ankle sprain injury": 231.333333,
    "hamstring injury injury": 220.000000,
    "meniscus tear injury": 201.250000,
    "torn hamstring injury": 187.666667,
    "dislocated shoulder injury": 269.000000,
    "ankle fracture injury": 114.500000,
    "fractured hand injury": 169.142857,
    "bone spurs injury": 151.500000,
    "acl tear injury": 268.000000,
    "hip labrum injury": 247.500000,
    "back surgery injury": 215.800000,
    "arm injury injury": 303.666667,
    "torn shoulder labrum injury": 195.666667,
    "lower back spasm injury": 234.000000,
}
# Load player dataset
@st.cache_data
def load_player_data():
    """Load the player dataset from ``player_data.csv``.

    The path is relative to the process's working directory. The result is
    cached with ``st.cache_data``, so the CSV is parsed once instead of on every
    Streamlit rerun (every widget interaction reruns the script). Each caller
    receives its own copy of the cached DataFrame, so callers may modify it
    safely.

    Returns:
        pandas.DataFrame: The raw dataset, one row per CSV record.
    """
    return pd.read_csv("player_data.csv")
# Load Random Forest model
@st.cache_resource
def load_rf_model():
    """Load the trained model from ``rf_injury_change_model.pkl``.

    The result is cached with ``st.cache_resource``, so the file is deserialized
    once per server process rather than on every rerun. The same object is
    shared by all reruns and sessions, so callers must treat it as read-only.

    Returns:
        The unpickled estimator object.

    Raises:
        FileNotFoundError: If the model file does not exist. Exceptions are not
            cached, so a later call retries the load.
    """
    # joblib.load unpickles the file, which can execute arbitrary code: only
    # point this at a trusted, locally shipped model file.
    return joblib.load("rf_injury_change_model.pkl")
# Main Streamlit app

# (dataset column, display label) for each per-season line chart; the chart
# title and y-axis label are built from the label.
_PERFORMANCE_STATS = (
    ("pts", "Points Per Game (PPG)"),
    ("ast", "Assists Per Game (AST)"),
    ("reb", "Rebounds Per Game (REB)"),
)


def _render_predictions(predictions):
    """Render the model output as an HTML table plus a bar chart.

    Args:
        predictions: Result of ``rf_model.predict`` for a single input row. It
            must have one column per label below; otherwise ``pd.DataFrame``
            raises ``ValueError``, which ``_render_injury_predictor`` reports.
    """
    prediction_columns = ["Predicted Change in PTS", "Predicted Change in REB", "Predicted Change in AST"]
    # Build the frame before drawing anything so a shape mismatch does not leave
    # a half-rendered section behind.
    prediction_df = pd.DataFrame(predictions, columns=prediction_columns)

    st.subheader("Predicted Post-Injury Performance")
    st.write("Based on the inputs, here are the predicted metrics:")
    # "styled-table" is the CSS class defined in the page-level <style> block.
    styled_table = prediction_df.style.set_table_attributes('class="styled-table"')
    st.write(styled_table.to_html(), unsafe_allow_html=True)

    # One bar (and legend entry) per predicted metric, each in its own color.
    palette = px.colors.qualitative.Plotly
    fig = go.Figure()
    for idx, col in enumerate(prediction_columns):
        fig.add_trace(go.Bar(
            x=[col],
            y=prediction_df[col],
            name=col,
            marker=dict(color=palette[idx]),
        ))
    fig.update_layout(
        title="Predicted Performance Changes",
        xaxis_title="Metrics",
        yaxis_title="Change Value",
        template="plotly_dark",
        showlegend=True,
    )
    st.plotly_chart(fig)


def _render_injury_predictor(player_row):
    """Render the sidebar inputs for one player and, on request, the prediction.

    The sidebar shows the player's position, editable age/height/weight, and the
    hypothetical-injury inputs. A "Predict" button runs the model on those
    inputs.

    Args:
        player_row: Non-empty slice of the dataset for the selected player. Only
            its first row is read.

    Returns:
        dict: ``'age'``, ``'player_height'`` and ``'player_weight'`` mapped to
        the (possibly user-edited) values. A value is 0 when the dataset lacks
        that column.
    """
    details = player_row.iloc[0]
    position = details['position']
    position_numeric = position_mapping.get(position, 0)
    st.sidebar.write(f"**Position**: {position} (Numeric: {position_numeric})")

    # Prefill from the dataset, then let the user override each value.
    stats_columns = ['age', 'player_height', 'player_weight']
    player_stats = {
        stat: details[stat] if stat in player_row.columns else 0
        for stat in stats_columns
    }
    for stat in stats_columns:
        player_stats[stat] = st.sidebar.number_input(f"{stat}", value=player_stats[stat])

    # Injury details; the slider defaults to the historical average for the type.
    injury_type = st.sidebar.selectbox("Select Hypothetical Injury", injury_types)
    default_days_injured = average_days_injured.get(injury_type) or 30  # 30 if unknown
    days_injured = st.sidebar.slider(
        "Estimated Days Injured",
        0,
        365,
        int(default_days_injured),
        help=f"Default days for {injury_type}: {int(default_days_injured)}",
    )
    injury_occurrences = st.sidebar.number_input("Injury Occurrences", min_value=0, value=1)

    input_data = pd.DataFrame([{
        "days_injured": days_injured,
        "injury_occurrences": injury_occurrences,
        "position": position_numeric,
        # pd.factorize on this one-row frame always returned 0, so the selected
        # injury never reached the model. Encode it by its index in injury_types.
        # NOTE(review): assumes training factorized injury types in this same
        # order (the list looks like a `.unique()` dump, which matches
        # factorize order) — confirm against the training code.
        "injury_type": injury_types.index(injury_type),
        **player_stats,
    }])

    try:
        rf_model = load_rf_model()
        # Match the model's training columns and order; any feature the app does
        # not collect is filled with 0.
        input_data = input_data.reindex(columns=rf_model.feature_names_in_, fill_value=0)
        if st.sidebar.button("Predict 🔮"):
            _render_predictions(rf_model.predict(input_data))
    except FileNotFoundError:
        st.error("Model file not found.")
    except ValueError as e:
        st.error(f"Error during prediction: {e}")

    return player_stats


def _render_player_overview(player_stats, team_name):
    """Render the player's age/height/weight and the team logo side by side.

    Args:
        player_stats: Dict returned by ``_render_injury_predictor``.
        team_name: Value looked up in ``team_logo_paths``. No logo is shown when
            it is not a key there.
    """
    st.divider()
    st.header("Player Overview")
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Player Details")
        st.markdown(f"""
        <div style="margin-bottom: 20px;">
        <div style="font-size: 1em; color: white; margin-bottom: 5px;">Age</div>
        <div style="font-size: 2em; color: white; font-weight: bold;">{player_stats['age']}</div>
        </div>
        <div style="margin-bottom: 20px;">
        <div style="font-size: 1em; color: white; margin-bottom: 5px;">Height (cm)</div>
        <div style="font-size: 2em; color: white; font-weight: bold;">{round(player_stats['player_height'], 2)}</div>
        </div>
        <div style="margin-bottom: 20px;">
        <div style="font-size: 1em; color: white; margin-bottom: 5px;">Weight (kg)</div>
        <div style="font-size: 2em; color: white; font-weight: bold;">{round(player_stats['player_weight'], 2)}</div>
        </div>
        """, unsafe_allow_html=True)
    with col2:
        # NOTE(review): team_name comes from the `team_abbreviation` column, but
        # team_logo_paths is keyed by full team names. If the dataset stores
        # abbreviations, no logo is shown — confirm.
        logo_path = team_logo_paths.get(team_name)
        if logo_path is not None:
            try:
                logo_image = Image.open(logo_path)
            except FileNotFoundError:
                st.error(f"Logo for {team_name} not found.")
            else:
                st.image(logo_image, caption=f"{team_name} Logo", use_container_width=True)


def _render_performance_graphs(player_data, player_name):
    """On button press, plot the player's per-season PTS, AST and REB lines.

    Args:
        player_data: Full dataset.
        player_name: Selected player; may be None if nothing is selected.
    """
    st.divider()
    st.header("Player Performance Graphs")
    if not st.button("Show Performance Graphs"):
        return

    player_seasons = player_data[player_data["player_name"] == player_name]
    # Check before the merge below: the merged frame always has one row per
    # dataset season, so checking it would never detect a missing player.
    if player_seasons.empty:
        st.error("No data available for the selected player.")
        return

    # Left-join onto every season in the dataset so that seasons the player
    # missed show up as gaps. A left merge keeps the (ascending) season order.
    # Assumes `season` holds integers, as range() requires — TODO confirm.
    all_seasons = pd.Series(range(player_data["season"].min(), player_data["season"].max() + 1))
    player_seasons = pd.DataFrame({"season": all_seasons}).merge(player_seasons, on="season", how="left")

    for column, label in _PERFORMANCE_STATS:
        fig = px.line(
            player_seasons,
            x="season",
            y=column,
            title=f"{player_name}: {label} Over Seasons",
            labels={column: label, "season": "Season"},
            markers=True,
        )
        fig.update_layout(template="plotly_white")
        st.plotly_chart(fig, use_container_width=True)


def main():
    """Render the NBA Player Performance Predictor app.

    Layout:
        Sidebar: pick a player, adjust their stats, describe a hypothetical
            injury, and request a prediction from the Random Forest model.
        Main area: the prediction, a player overview with the team logo, and
            per-season PTS/AST/REB line charts on demand.
    """
    st.title("NBA Player Performance Predictor 🏀")
    st.write(
        """
        Predict how a player's performance metrics (e.g., points, rebounds, assists) might change
        if a hypothetical injury occurs, based on their position and other factors.
        """
    )

    player_data = load_player_data()

    st.sidebar.markdown(
        """
        <div style="padding: 10px; background: linear-gradient(to right, #6a11cb, #2575fc); color: white; border-radius: 10px;">
        <h3>Player Details</h3>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Dropdown for player selection
    player_list = sorted(player_data['player_name'].dropna().unique())
    player_name = st.sidebar.selectbox("Select Player", player_list)

    # Both stay None unless a player with at least one dataset row is selected.
    player_stats = None
    team_name = None
    if not player_name:
        st.sidebar.error("Please select a player to view details.")
    else:
        player_row = player_data[player_data['player_name'] == player_name]
        if player_row.empty:
            st.sidebar.error("Player details not found in the dataset.")
        else:
            # Read only after the emptiness check: .iloc[0] raises IndexError on
            # an empty frame.
            team_name = player_row.iloc[0]['team_abbreviation']
            player_stats = _render_injury_predictor(player_row)

    # The overview needs the player's stats, so skip it when none are available.
    if player_stats is not None:
        _render_player_overview(player_stats, team_name)

    _render_performance_graphs(player_data, player_name)

    # Footer
    st.divider()
    st.markdown("""
    ### About This Tool
    This application predicts how injuries might impact an NBA player's performance using machine learning models. Data is based on historical player stats and injuries.
    """)
if __name__ == "__main__":
    # `streamlit run` executes this script as __main__, so the app renders here.
    main()