import os

from huggingface_hub import hf_hub_download
import streamlit as st
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.interpolate import make_interp_spline
import numpy as np

# Hugging Face access token read from the environment; None when unset.
hf_token = os.getenv("HF_TOKEN")

# Columns of the uploaded CSV that this app reads directly.
_REQUIRED_COLUMNS = (
    'match_id',
    'elapsed_time',
    'server',
    'set_no',
    'game_no',
    'point_victor',
    'p1_facing_game_point',
    'p2_facing_game_point',
    'p1_PAI',
    'p2_PAI',
    'normalized_rally',
)


def load_model(repo_id, filename):
    """Download a serialized model from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository that hosts the model file.
        filename: Name of the model file inside the repository.

    Returns:
        The object deserialized by ``joblib.load``.
    """
    # `token` is the documented keyword; `use_auth_token` is the deprecated
    # spelling of the same argument.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code. Only point this at repositories you trust.
    return joblib.load(model_path)


# Global Seaborn style for every figure drawn by this app.
sns.set(style="whitegrid")


def preprocess_and_predict(model, data, match_id, features):
    """Select one match's rows and predict the positive-class probability.

    Rows are restricted to ``match_id``, de-duplicated on ``elapsed_time``
    (first occurrence kept), stripped of rows whose ``server`` is 0, and
    sorted by ``elapsed_time``.

    Args:
        model: Classifier exposing ``predict_proba``.
        data: DataFrame holding the match data and the feature columns.
        match_id: Value of the ``match_id`` column to keep.
        features: Column names passed to the model.

    Returns:
        ``(match_rows, probabilities)`` where ``probabilities`` is column 1 of
        ``model.predict_proba``; ``(None, None)`` when no rows remain (a
        message is written to the page in that case).
    """
    match_rows = data[data['match_id'] == match_id]
    match_rows = match_rows.drop_duplicates(subset=['elapsed_time'], keep='first')
    match_rows = match_rows[match_rows['server'] != 0]

    if match_rows.empty:
        st.write(f"No data found for match_id {match_id}.")
        return None, None

    match_rows = match_rows.sort_values('elapsed_time')
    probabilities = model.predict_proba(match_rows[features])[:, 1]
    return match_rows, probabilities


def _compute_momentum(probabilities):
    """Return the per-point momentum values derived from the probabilities.

    The first point is the mean of its probability and 0.5. Every later point
    is the mean of the previous probability, its own probability and a "next"
    value, which is 0.5 for all points except the last one, where the point's
    own probability is used instead.
    """
    probs = np.asarray(probabilities, dtype=float)
    count = probs.size
    if count == 0:
        return np.empty(0)

    momentum = np.empty(count)
    momentum[0] = (probs[0] + 0.5) / 2
    if count > 1:
        next_values = np.full(count - 1, 0.5)
        next_values[-1] = probs[-1]
        momentum[1:] = (probs[:-1] + probs[1:] + next_values) / 3
    return momentum


def _smooth_curve(x, y, num=300):
    """Interpolate ``y`` over ``x`` on ``num`` evenly spaced points.

    A cubic spline is used when possible; with fewer than four points the
    spline degree is lowered to ``len(x) - 1`` because an interpolating spline
    of degree k needs at least k + 1 points. A single point (or none) is
    returned unchanged.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.size < 2:
        return x, y
    degree = min(3, x.size - 1)
    grid = np.linspace(x.min(), x.max(), num)
    return grid, make_interp_spline(x, y, k=degree)(grid)


def plot_results(specific_match_data, positive_class_probabilities, match_id, show_true_label=True, show_predicted_probability=True, show_momentum=True):
    """Draw the selected series for one match and render them in Streamlit.

    Args:
        specific_match_data: Rows to plot; must contain ``elapsed_time``,
            ``point_victor``, ``game_no`` and ``set_no``.
        positive_class_probabilities: One probability per row of
            ``specific_match_data``.
        match_id: Match identifier shown in the title.
        show_true_label: Plot ``point_victor - 1`` as red markers.
        show_predicted_probability: Plot the probabilities as a blue line.
        show_momentum: Plot the smoothed momentum curve as a green area.
    """
    fig, ax = plt.subplots(figsize=(14, 7))
    # Close the figure even if drawing fails: Streamlit reruns the script on
    # every interaction, so unclosed figures would pile up.
    try:
        elapsed = specific_match_data['elapsed_time']

        if show_predicted_probability:
            sns.lineplot(
                x=elapsed,
                y=positive_class_probabilities,
                marker='o',
                linestyle='-',
                color='blue',
                label='Predicted Probability',
                ax=ax,
            )

        if show_true_label:
            sns.scatterplot(
                x=elapsed,
                y=(specific_match_data['point_victor'] - 1),
                color='red',
                label='True Label',
                s=60,
                ax=ax,
            )

        if show_momentum:
            # Trim so the momentum never has more entries than there are rows.
            momentum = _compute_momentum(positive_class_probabilities)[:len(specific_match_data)]
            x_smooth, momentum_smooth = _smooth_curve(elapsed.to_numpy(), momentum)
            ax.fill_between(x_smooth, momentum_smooth, color='green', alpha=0.3)
            ax.plot(x_smooth, momentum_smooth, color='green', label='Momentum')

        # Mark the moments where game_no (gray) or set_no (red) changes
        # compared with the previous row.
        times = elapsed.to_numpy()
        games = specific_match_data['game_no'].to_numpy()
        sets = specific_match_data['set_no'].to_numpy()
        boundaries = zip(times[1:], games[1:] != games[:-1], sets[1:] != sets[:-1])
        for moment, game_changed, set_changed in boundaries:
            if game_changed:
                ax.axvline(x=moment, color='gray', linestyle='--', lw=2)
            if set_changed:
                ax.axvline(x=moment, color='red', linestyle='-.', lw=2)

        ax.set_title(f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}')
        ax.set_xlabel('Elapsed Time')
        ax.set_ylabel('Probability / True Label / Momentum')
        ax.grid(True)
        # Only build a legend when at least one labelled series was drawn.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

        st.pyplot(fig)
    finally:
        plt.close(fig)


def main():
    """Run the Streamlit page: upload a CSV, pick a match, plot the results."""
    st.title('Momentum Catcher')
    st.markdown(
        "To get started, you can find sample data available for download at "
        "[Hugging Face Spaces](https://huggingface.co/spaces/Nagi-ovo/Tennis-Momentum-Tracker/tree/main/data). "
        "This data can be used directly in this application to analyze tennis match momentum."
    )

    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
    if uploaded_file is None:
        return

    new_data = pd.read_csv(uploaded_file)

    missing = [column for column in _REQUIRED_COLUMNS if column not in new_data.columns]
    if missing:
        st.error(f"The uploaded CSV is missing required column(s): {', '.join(missing)}")
        return

    # dropna returns a new frame, so the result must be kept. Only the columns
    # this app reads are considered, so gaps in unrelated columns do not
    # discard otherwise usable rows.
    new_data = new_data.dropna(subset=list(_REQUIRED_COLUMNS))

    # Let the user choose among every match_id present in the upload.
    unique_match_ids = new_data['match_id'].unique()
    match_id_input = st.selectbox("Select the match_id you want to analyze", unique_match_ids)

    # Model location on the Hugging Face Hub.
    repo_id = "Nagi-ovo/Momentum-XGboost"
    filename = "xgboost.pkl"
    xgb_model = load_model(repo_id, filename)

    features = ['PAI_diff', 'normalized_rally', 'is_game_point']
    # Derived feature columns expected by the model.
    new_data = new_data.assign(
        is_game_point=(new_data['p1_facing_game_point'] - new_data['p2_facing_game_point']).abs(),
        PAI_diff=new_data['p1_PAI'] - new_data['p2_PAI'],
    )

    specific_match_data, positive_class_probabilities = preprocess_and_predict(
        xgb_model, new_data, match_id_input, features
    )
    if specific_match_data is None:
        return

    # Default to the whole match; optionally narrow to chosen sets and games.
    filtered_data = specific_match_data
    if st.checkbox("Observe specific set(s) and game(s)", False):
        unique_sets = specific_match_data['set_no'].unique()
        selected_sets = st.multiselect('Select set(s) to observe', unique_sets, default=unique_sets)

        # With no set selected the full match stays on display.
        if len(selected_sets) > 0:
            filtered_data = specific_match_data[specific_match_data['set_no'].isin(selected_sets)]
            unique_games = filtered_data['game_no'].unique()
            selected_games = st.multiselect(
                'Select games to observe (default is all games)', unique_games, default=unique_games
            )
            # With no game selected every game of the chosen sets is kept.
            if len(selected_games) > 0:
                filtered_data = filtered_data[filtered_data['game_no'].isin(selected_games)]

    # Series toggles.
    show_true_label = st.checkbox("Show True Label", True)
    show_predicted_probability = st.checkbox("Show Predicted Probability", True)
    show_momentum = st.checkbox("Show Momentum", True)

    if filtered_data.empty:
        st.write("No data available for the selected set(s) and games. Please select again.")
        return

    # Reuse the probabilities already computed for the whole match instead of
    # running the model a second time on the same rows.
    probability_by_row = pd.Series(positive_class_probabilities, index=specific_match_data.index)
    filtered_positive_class_probabilities = probability_by_row.loc[filtered_data.index].to_numpy()

    plot_results(
        filtered_data,
        filtered_positive_class_probabilities,
        match_id_input,
        show_true_label,
        show_predicted_probability,
        show_momentum,
    )


if __name__ == '__main__':
    main()