Nagi-ovo's picture
deployed v1.0.4
037c01c
from huggingface_hub import hf_hub_download
import streamlit as st
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.interpolate import make_interp_spline
import numpy as np
import os
# Hugging Face access token taken from the HF_TOKEN environment variable (None when it is unset).
hf_token = os.getenv("HF_TOKEN")
def load_model(repo_id, filename):
    """Download a serialized model from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository that hosts the model file.
        filename: Name of the model file inside that repository.

    Returns:
        The object deserialized by ``joblib.load`` from the downloaded file.
    """
    # Authenticate with the module-level HF_TOKEN value. ``token`` is the
    # supported keyword; ``use_auth_token`` is deprecated in huggingface_hub.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only point this at repositories you trust.
    return joblib.load(model_path)
# Apply Seaborn's "whitegrid" style to the plots drawn by this app.
sns.set(style="whitegrid")
def preprocess_and_predict(model, data, match_id, features):
    """Select one match's rows and predict a probability for each of them.

    Rows are restricted to ``match_id``, de-duplicated on ``elapsed_time``
    (first occurrence kept), stripped of rows whose ``server`` is 0 and then
    ordered by ``elapsed_time``.

    Args:
        model: Object exposing ``predict_proba``.
        data: DataFrame holding ``match_id``, ``elapsed_time``, ``server`` and
            every column named in ``features``.
        match_id: Identifier of the match to keep.
        features: Column names passed to the model.

    Returns:
        ``(rows, probabilities)`` where ``probabilities`` is column 1 of
        ``model.predict_proba``; ``(None, None)`` when no rows remain.
    """
    match_rows = (
        data.loc[data['match_id'] == match_id]
        .drop_duplicates(subset=['elapsed_time'], keep='first')
    )
    match_rows = match_rows.loc[match_rows['server'] != 0]

    # Nothing left to predict on: tell the user and bail out early.
    if match_rows.empty:
        st.write(f"No data found for match_id {match_id}.")
        return None, None

    match_rows = match_rows.sort_values('elapsed_time')
    class_one_probabilities = model.predict_proba(match_rows[features])[:, 1]
    return match_rows, class_one_probabilities
def _momentum_from_probabilities(probabilities):
    """Return one momentum value per probability.

    The first value averages the first probability with 0.5. Every later
    value averages the previous probability, the current one and a stand-in
    for the next point: 0.5, except for the last point, which reuses its own
    probability.
    """
    probs = np.asarray(probabilities, dtype=float)
    if probs.size == 0:
        return probs
    momentum = np.empty(probs.size, dtype=float)
    momentum[0] = (probs[0] + 0.5) / 2
    if probs.size > 1:
        momentum[1:] = (probs[:-1] + probs[1:] + 0.5) / 3
        momentum[-1] = (probs[-2] + probs[-1] + probs[-1]) / 3
    return momentum


def _draw_momentum(ax, elapsed_time, probabilities):
    """Draw the momentum curve on ``ax``, smoothed when enough points exist."""
    # Never use more momentum values than there are time stamps.
    momentum = _momentum_from_probabilities(probabilities)[:len(elapsed_time)]
    n_points = len(momentum)
    if n_points == 0:
        return
    if n_points == 1:
        # A spline needs at least two points; mark the single value instead.
        ax.plot(elapsed_time, momentum, marker='o', color='green', label='Momentum')
        return

    # An interpolating spline of degree k needs k + 1 points, so short
    # selections (for example a single game) fall back to a lower degree.
    degree = min(3, n_points - 1)
    x_smooth = np.linspace(elapsed_time.min(), elapsed_time.max(), 300)
    momentum_smooth = make_interp_spline(elapsed_time, momentum, k=degree)(x_smooth)
    ax.fill_between(x_smooth, momentum_smooth, color='green', alpha=0.3)
    ax.plot(x_smooth, momentum_smooth, color='green', label='Momentum')


def plot_results(specific_match_data, positive_class_probabilities, match_id, show_true_label=True, show_predicted_probability=True, show_momentum=True):
    """Render the match timeline in the Streamlit page.

    Args:
        specific_match_data: DataFrame with ``elapsed_time``, ``game_no`` and
            ``set_no`` columns, plus ``point_victor`` when true labels are
            shown. Rows are expected in increasing ``elapsed_time`` order.
        positive_class_probabilities: One predicted probability per row.
        match_id: Identifier shown in the plot title.
        show_true_label: Draw ``point_victor - 1`` as red markers.
        show_predicted_probability: Draw the predicted probabilities as a line.
        show_momentum: Draw the smoothed momentum curve.
    """
    elapsed = specific_match_data['elapsed_time']
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        if show_predicted_probability:
            sns.lineplot(
                x=elapsed,
                y=positive_class_probabilities,
                marker='o',
                linestyle='-',
                color='blue',
                label='Predicted Probability',
                ax=ax,
            )

        if show_true_label:
            sns.scatterplot(
                x=elapsed,
                y=(specific_match_data['point_victor'] - 1),
                color='red',
                label='True Label',
                s=60,
                ax=ax,
            )

        if show_momentum:
            _draw_momentum(ax, elapsed, positive_class_probabilities)

        # Vertical markers where a new game (gray) or a new set (red) begins.
        times = elapsed.to_numpy()
        games = specific_match_data['game_no'].to_numpy()
        sets = specific_match_data['set_no'].to_numpy()
        for when, prev_game, game, prev_set, set_no in zip(
            times[1:], games[:-1], games[1:], sets[:-1], sets[1:]
        ):
            if game != prev_game:
                ax.axvline(x=when, color='gray', linestyle='--', lw=2)
            if set_no != prev_set:
                ax.axvline(x=when, color='red', linestyle='-.', lw=2)

        ax.set_title(f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}')
        ax.set_xlabel('Elapsed Time')
        ax.set_ylabel('Probability / True Label / Momentum')
        ax.grid(True)
        # Skip the legend when every series is switched off.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        st.pyplot(fig)
    finally:
        # Streamlit reruns the script on every interaction; close the figure
        # so open figures do not pile up in memory.
        plt.close(fig)
# Columns read from the uploaded CSV to build features, filter rows and draw
# the timeline. 'point_victor' is only used by the optional true-label
# overlay, so it is not enforced here.
_REQUIRED_COLUMNS = (
    'match_id', 'elapsed_time', 'server', 'set_no', 'game_no',
    'normalized_rally', 'p1_facing_game_point', 'p2_facing_game_point',
    'p1_PAI', 'p2_PAI',
)


def main():
    """Run the Streamlit page: upload a CSV, pick a match and plot it."""
    st.title('Momentum Catcher')
    st.markdown("""
To get started, you can find sample data available for download at
[Hugging Face Spaces](https://huggingface.co/spaces/Nagi-ovo/Tennis-Momentum-Tracker/tree/main/data).
This data can be used directly in this application to analyze tennis match momentum.
""")
    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
    if uploaded_file is None:
        return

    try:
        new_data = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
        st.error(f"Could not read the uploaded CSV: {exc}")
        return

    missing = [column for column in _REQUIRED_COLUMNS if column not in new_data.columns]
    if missing:
        st.error(f"The uploaded CSV is missing required column(s): {', '.join(missing)}")
        return

    # dropna returns a new frame, so the result must be kept. Only the
    # columns the app relies on are checked, so gaps in unrelated columns do
    # not discard usable rows.
    new_data = new_data.dropna(subset=list(_REQUIRED_COLUMNS))

    # Let the user choose among the match ids present in the file.
    match_id_input = st.selectbox("Select the match_id you want to analyze", new_data['match_id'].unique())

    xgb_model = load_model("Nagi-ovo/Momentum-XGboost", "xgboost.pkl")

    features = ['PAI_diff', 'normalized_rally', 'is_game_point']
    # assign() builds a fresh frame, so the derived columns are never written
    # into a filtered view of the upload.
    new_data = new_data.assign(
        is_game_point=(new_data['p1_facing_game_point'] - new_data['p2_facing_game_point']).abs(),
        PAI_diff=new_data['p1_PAI'] - new_data['p2_PAI'],
    )

    specific_match_data, positive_class_probabilities = preprocess_and_predict(xgb_model, new_data, match_id_input, features)
    if specific_match_data is None:
        return

    # The whole match is shown unless the user narrows it down below.
    filtered_data = specific_match_data
    if st.checkbox("Observe specific set(s) and game(s)", False):
        unique_sets = specific_match_data['set_no'].unique()
        selected_sets = st.multiselect('Select set(s) to observe', unique_sets, default=unique_sets)
        # With no set selected the whole match stays visible.
        if selected_sets:
            filtered_data = specific_match_data[specific_match_data['set_no'].isin(selected_sets)]
            unique_games = filtered_data['game_no'].unique()
            selected_games = st.multiselect('Select games to observe (default is all games)', unique_games, default=unique_games)
            # With no game selected, every game of the chosen sets stays visible.
            if selected_games:
                filtered_data = filtered_data[filtered_data['game_no'].isin(selected_games)]

    show_true_label = st.checkbox("Show True Label", True)
    show_predicted_probability = st.checkbox("Show Predicted Probability", True)
    show_momentum = st.checkbox("Show Momentum", True)

    if filtered_data.empty:
        st.write("No data available for the selected set(s) and games. Please select again.")
        return

    # Reuse the probabilities already predicted for the match rather than
    # running the model again on the selected rows.
    match_probabilities = pd.Series(positive_class_probabilities, index=specific_match_data.index)
    filtered_positive_class_probabilities = match_probabilities.loc[filtered_data.index].to_numpy()
    plot_results(filtered_data, filtered_positive_class_probabilities, match_id_input, show_true_label, show_predicted_probability, show_momentum)


if __name__ == '__main__':
    main()