|
from huggingface_hub import hf_hub_download |
|
import streamlit as st |
|
import pandas as pd |
|
import joblib |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from scipy.interpolate import make_interp_spline |
|
import numpy as np |
|
import os |
|
|
|
# Optional Hugging Face access token taken from the environment (None when
# HF_TOKEN is not set); it is forwarded to the Hub when the model is downloaded.
hf_token = os.environ.get("HF_TOKEN")
|
|
|
@st.cache_resource
def load_model(repo_id, filename):
    """Download a joblib-serialised model from the Hugging Face Hub and load it.

    Wrapped in ``st.cache_resource`` so the Hub request and the
    deserialisation happen once per ``(repo_id, filename)`` instead of on every
    Streamlit rerun (every widget interaction re-executes the script). The
    cached object is shared by all sessions of the app.

    Args:
        repo_id: Hub repository id, e.g. ``"Nagi-ovo/Momentum-XGboost"``.
        filename: Name of the file inside the repository.

    Returns:
        The object returned by ``joblib.load`` for the downloaded file.
    """
    # ``token`` is the current name of the deprecated ``use_auth_token``
    # argument; ``hf_token`` is the module-level HF_TOKEN value (may be None).
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)

    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load files from repositories you trust.
    return joblib.load(model_path)
|
|
|
|
|
# Apply seaborn's "whitegrid" theme globally to every plot this app draws.
# ``sns.set`` is only a legacy alias of ``sns.set_theme`` (seaborn >= 0.11).
sns.set_theme(style="whitegrid")
|
|
|
def preprocess_and_predict(model, data, match_id, features):
    """Select one match, clean its rows, and score every remaining point.

    Rows belonging to ``match_id`` are de-duplicated on ``elapsed_time``
    (the first occurrence wins), rows whose ``server`` equals 0 are then
    discarded, and what is left is ordered by ``elapsed_time``.
    ``model.predict_proba`` is applied to the ``features`` columns.

    Args:
        model: Any object exposing ``predict_proba`` that returns a 2-D array.
        data: DataFrame with at least ``match_id``, ``elapsed_time``,
            ``server`` and the ``features`` columns. It is not modified.
        match_id: Value of ``match_id`` to keep.
        features: Column names passed to the model.

    Returns:
        ``(match_rows, probabilities)``, where ``probabilities`` is column 1
        of ``predict_proba`` aligned with ``match_rows``. Returns
        ``(None, None)`` after writing a notice to the page when no rows
        survive the filtering.
    """
    match_rows = data.loc[data['match_id'] == match_id]
    # De-duplicate *before* dropping server == 0 rows, as the original did.
    match_rows = match_rows.drop_duplicates(subset=['elapsed_time'], keep='first')
    match_rows = match_rows.loc[match_rows['server'] != 0]

    if match_rows.empty:
        st.write(f"No data found for match_id {match_id}.")
        return None, None

    match_rows = match_rows.sort_values('elapsed_time')
    class_scores = model.predict_proba(match_rows[features])
    return match_rows, class_scores[:, 1]
|
|
|
def _compute_momentum(probabilities):
    """Return the "momentum" series derived from per-point probabilities.

    The weighting reproduces the app's original loop exactly:

    * index 0: ``(p[0] + 0.5) / 2``
    * interior index i (0 < i < n-1): ``(p[i-1] + p[i] + 0.5) / 3``
    * last index n-1 (n >= 2): ``(p[n-2] + p[n-1] + p[n-1]) / 3``

    NOTE(review): interior points average in the neutral value 0.5 rather
    than the next probability ``p[i+1]``, while the last point repeats
    ``p[n-1]``. That asymmetry may have been meant as a centred moving
    average; confirm the intended definition before changing it.

    Args:
        probabilities: 1-D sequence of probabilities.

    Returns:
        float64 NumPy array of the same length (empty for empty input).
    """
    probs = np.asarray(probabilities, dtype=float)
    count = probs.size
    momentum = np.empty(count)
    if count == 0:
        return momentum
    momentum[0] = (probs[0] + 0.5) / 2
    if count > 1:
        # Third term of each 3-point average: 0.5 for interior points,
        # the point's own probability for the final point.
        third_term = np.full(count - 1, 0.5)
        third_term[-1] = probs[-1]
        momentum[1:] = (probs[:-1] + probs[1:] + third_term) / 3
    return momentum


def _draw_boundaries(ax, data, column, **line_kwargs):
    """Draw a vertical line at each row whose ``data[column]`` differs from the previous row."""
    values = data[column].to_numpy()
    times = data['elapsed_time'].to_numpy()
    for idx in np.flatnonzero(values[1:] != values[:-1]) + 1:
        ax.axvline(x=times[idx], **line_kwargs)


def plot_results(specific_match_data, positive_class_probabilities, match_id, show_true_label=True, show_predicted_probability=True, show_momentum=True):
    """Plot probabilities, true labels and momentum for one match in Streamlit.

    Args:
        specific_match_data: Rows of a single match, ordered by a strictly
            increasing ``elapsed_time`` (as produced by
            ``preprocess_and_predict``). Reads ``elapsed_time``, ``game_no``,
            ``set_no`` and, when ``show_true_label`` is set, ``point_victor``.
        positive_class_probabilities: Per-row positive-class probability,
            aligned with ``specific_match_data``.
        match_id: Match identifier shown in the title.
        show_true_label: Scatter ``point_victor - 1`` in red.
        show_predicted_probability: Draw the probability line in blue.
        show_momentum: Draw the smoothed momentum curve in green, filled to 0.

    Game changes are marked with dashed grey lines and set changes with
    dash-dot red lines, independent of the ``show_*`` flags.
    """
    fig, ax = plt.subplots(figsize=(14, 7))
    times = specific_match_data['elapsed_time']

    if show_predicted_probability:
        sns.lineplot(
            x=times,
            y=positive_class_probabilities,
            marker='o',
            linestyle='-',
            color='blue',
            label='Predicted Probability',
            ax=ax,
        )

    if show_true_label:
        # Shift point_victor down by one so it shares the 0-1 probability
        # axis (presumably it is encoded as 1/2 — confirm against the data).
        sns.scatterplot(
            x=times,
            y=specific_match_data['point_victor'] - 1,
            color='red',
            label='True Label',
            s=60,
            ax=ax,
        )

    if show_momentum:
        momentum = _compute_momentum(positive_class_probabilities)[:len(specific_match_data)]
        x_values = times.to_numpy()
        if len(x_values) >= 2:
            # make_interp_spline needs at least k + 1 points; lower the
            # degree for short selections instead of raising ValueError.
            degree = min(3, len(x_values) - 1)
            x_smooth = np.linspace(times.min(), times.max(), 300)
            momentum_smooth = make_interp_spline(x_values, momentum, k=degree)(x_smooth)
            ax.fill_between(x_smooth, momentum_smooth, color='green', alpha=0.3)
            ax.plot(x_smooth, momentum_smooth, color='green', label='Momentum')
        else:
            # A single point cannot be interpolated; show it as a marker.
            ax.plot(x_values, momentum, color='green', marker='o', label='Momentum')

    _draw_boundaries(ax, specific_match_data, 'game_no', color='gray', linestyle='--', lw=2)
    _draw_boundaries(ax, specific_match_data, 'set_no', color='red', linestyle='-.', lw=2)

    ax.set_title(f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}')
    ax.set_xlabel('Elapsed Time')
    ax.set_ylabel('Probability / True Label / Momentum')
    ax.grid(True)
    ax.legend()
    st.pyplot(fig)
    # The figure has been rendered; close it so pyplot does not keep one
    # open figure per Streamlit rerun.
    plt.close(fig)
|
|
|
# Columns the app needs from the uploaded CSV. ``point_victor`` is only used
# when true labels are plotted, so it is deliberately not enforced here.
_REQUIRED_COLUMNS = (
    'match_id',
    'elapsed_time',
    'server',
    'set_no',
    'game_no',
    'p1_facing_game_point',
    'p2_facing_game_point',
    'p1_PAI',
    'p2_PAI',
    'normalized_rally',
)


def _select_rows(match_data):
    """Render the set/game selection widgets and return a boolean row mask.

    The mask is positional over ``match_data``. All rows are kept when the
    "observe specific" checkbox is off or no set is selected. When no game is
    selected, every game of the chosen sets is kept.
    """
    row_mask = np.ones(len(match_data), dtype=bool)

    if not st.checkbox("Observe specific set(s) and game(s)", False):
        return row_mask

    unique_sets = match_data['set_no'].unique()
    selected_sets = st.multiselect('Select set(s) to observe', unique_sets, default=unique_sets)
    if not selected_sets:
        return row_mask

    row_mask &= match_data['set_no'].isin(selected_sets).to_numpy()
    unique_games = match_data['game_no'][row_mask].unique()
    selected_games = st.multiselect('Select games to observe (default is all games)', unique_games, default=unique_games)
    if selected_games:
        row_mask &= match_data['game_no'].isin(selected_games).to_numpy()
    return row_mask


def main():
    """Render the Momentum Catcher Streamlit page.

    Flow: upload a CSV, pick a match, derive the model features, score the
    match with the XGBoost model from the Hub, optionally narrow the view to
    specific sets/games, and plot the result.
    """
    st.title('Momentum Catcher')

    st.markdown("""
To get started, you can find sample data available for download at
[Hugging Face Spaces](https://huggingface.co/spaces/Nagi-ovo/Tennis-Momentum-Tracker/tree/main/data).
This data can be used directly in this application to analyze tennis match momentum.
""")

    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
    if uploaded_file is None:
        return

    # Rows with NaN values are not dropped. A former bare ``dropna()`` call
    # discarded its result and so never removed anything.
    new_data = pd.read_csv(uploaded_file)

    missing_columns = [col for col in _REQUIRED_COLUMNS if col not in new_data.columns]
    if missing_columns:
        st.error("The uploaded CSV is missing required column(s): " + ", ".join(missing_columns))
        return

    unique_match_ids = new_data['match_id'].unique()
    match_id_input = st.selectbox("Select the match_id you want to analyze", unique_match_ids)

    repo_id = "Nagi-ovo/Momentum-XGboost"
    filename = "xgboost.pkl"
    xgb_model = load_model(repo_id, filename)
    features = ['PAI_diff', 'normalized_rally', 'is_game_point']

    # Derived model inputs.
    new_data['is_game_point'] = (new_data['p1_facing_game_point'] - new_data['p2_facing_game_point']).abs()
    new_data['PAI_diff'] = new_data['p1_PAI'] - new_data['p2_PAI']

    specific_match_data, positive_class_probabilities = preprocess_and_predict(xgb_model, new_data, match_id_input, features)
    if specific_match_data is None:
        return

    row_mask = _select_rows(specific_match_data)
    filtered_data = specific_match_data[row_mask]

    show_true_label = st.checkbox("Show True Label", True)
    show_predicted_probability = st.checkbox("Show Predicted Probability", True)
    show_momentum = st.checkbox("Show Momentum", True)

    if filtered_data.empty:
        st.write("No data available for the selected set(s) and games. Please select again.")
        return

    # The same positional mask selects the already-computed probabilities,
    # so the model does not need to be run a second time.
    filtered_positive_class_probabilities = positive_class_probabilities[row_mask]
    plot_results(filtered_data, filtered_positive_class_probabilities, match_id_input, show_true_label, show_predicted_probability, show_momentum)
|
|
|
# Render the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()
|
|