Spaces:
Sleeping
Sleeping
File size: 8,497 Bytes
e2947cc 1aa7396 e2947cc 86d7e74 1aa7396 e2947cc 86d7e74 e2947cc 1aa7396 e2947cc 1aa7396 e2947cc 7456358 e2947cc 1aa7396 e2947cc 1aa7396 e2947cc 639ccd9 e2947cc 639ccd9 e2947cc 037c01c e2947cc 7456358 e2947cc 1aa7396 7456358 2db0224 1aa7396 e2947cc 1aa7396 e2947cc 1aa7396 e2947cc 1aa7396 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
from huggingface_hub import hf_hub_download
import streamlit as st
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.interpolate import make_interp_spline
import numpy as np
import os
# Hugging Face access token taken from the environment (None when HF_TOKEN is unset).
hf_token = os.environ.get("HF_TOKEN")
@st.cache_resource
def load_model(repo_id, filename):
    """Download ``filename`` from the Hugging Face Hub repo ``repo_id`` and load it with joblib.

    Authenticates with the module-level ``hf_token`` (taken from the
    ``HF_TOKEN`` environment variable). The result is cached with
    ``st.cache_resource`` so that Streamlit script reruns reuse the loaded
    model instead of fetching and deserializing it again each time.

    SECURITY: ``joblib.load`` unpickles the file, which can execute arbitrary
    code -- only point this at repositories you trust.

    Args:
        repo_id: Hub repository id, e.g. ``"owner/name"``.
        filename: Path of the serialized model inside the repository.

    Returns:
        The deserialized model object.
    """
    # `token` is the current keyword; `use_auth_token` is its deprecated alias.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    return joblib.load(model_path)
# Apply the seaborn "whitegrid" theme globally for every plot in the app.
# `set_theme` is the documented entry point; `sns.set` is only an alias for it.
sns.set_theme(style="whitegrid")
def preprocess_and_predict(model, data, match_id, features):
    """Select one match's points and score them with ``model``.

    Processing order:
      1. keep rows whose ``match_id`` equals ``match_id``;
      2. drop duplicate ``elapsed_time`` values, keeping the first occurrence;
      3. discard rows whose ``server`` is 0;
      4. sort the remainder by ``elapsed_time``.

    Args:
        model: Object exposing ``predict_proba`` (scikit-learn style).
        data: DataFrame holding every uploaded point.
        match_id: Match to analyse.
        features: Column names passed to ``model.predict_proba``.

    Returns:
        ``(match_rows, probabilities)`` where ``probabilities`` is column 1
        of ``model.predict_proba(match_rows[features])``; ``(None, None)``
        after writing a notice to the page when no rows survive filtering.
    """
    match_rows = (
        data.loc[data['match_id'] == match_id]
        .drop_duplicates(subset=['elapsed_time'], keep='first')
    )
    match_rows = match_rows.loc[match_rows['server'] != 0]

    if match_rows.empty:
        st.write(f"No data found for match_id {match_id}.")
        return None, None

    match_rows = match_rows.sort_values('elapsed_time')
    # Column 1 of predict_proba is the positive-class probability.
    probabilities = model.predict_proba(match_rows[features])[:, 1]
    return match_rows, probabilities
def _compute_momentum(probabilities):
    """Smooth per-point win probabilities into the "momentum" series (same length as input).

    Point 0 is the mean of its own probability and a neutral 0.5. Every later
    point is the mean of the previous probability, its own probability and a
    neutral 0.5 -- except the last point, which uses its own probability in
    place of the 0.5.
    """
    # NOTE(review): the real next-point probability is never used (0.5 stands in
    # for it); kept as-is because it defines the plotted metric -- confirm intent.
    probs = np.asarray(probabilities, dtype=float)
    momentum = np.empty_like(probs)
    if probs.size == 0:
        return momentum
    momentum[0] = (probs[0] + 0.5) / 2
    if probs.size > 1:
        momentum[1:] = (probs[:-1] + probs[1:] + 0.5) / 3
        # The final point has no successor, so its own value replaces the 0.5.
        momentum[-1] = (probs[-2] + probs[-1] + probs[-1]) / 3
    return momentum


def _draw_momentum(ax, elapsed_time, momentum):
    """Draw ``momentum`` against ``elapsed_time`` on ``ax`` as a smooth, filled green curve."""
    n_points = len(elapsed_time)
    if n_points == 0:
        return
    if n_points == 1:
        # A spline needs at least two points; show the lone value as a marker.
        ax.plot(elapsed_time, momentum, 'o', color='green', label='Momentum')
        return
    # make_interp_spline needs at least k + 1 points, so lower the degree for
    # very short selections instead of raising.
    degree = min(3, n_points - 1)
    x_smooth = np.linspace(elapsed_time.min(), elapsed_time.max(), 300)
    spline = make_interp_spline(elapsed_time, momentum, k=degree)
    momentum_smooth = spline(x_smooth)
    ax.fill_between(x_smooth, momentum_smooth, color='green', alpha=0.3)
    ax.plot(x_smooth, momentum_smooth, color='green', label='Momentum')


def _mark_set_and_game_changes(ax, data):
    """Draw a vertical line wherever ``game_no`` (gray dashed) or ``set_no`` (red dash-dot) changes."""
    times = data['elapsed_time'].to_numpy()
    for column, style in (
        ('game_no', {'color': 'gray', 'linestyle': '--', 'lw': 2}),
        ('set_no', {'color': 'red', 'linestyle': '-.', 'lw': 2}),
    ):
        values = data[column]
        # Compare each row with the previous one; the first row has no predecessor.
        changed = (values != values.shift()).to_numpy(copy=True)
        changed[:1] = False
        for t in times[changed]:
            ax.axvline(x=t, **style)


def plot_results(specific_match_data, positive_class_probabilities, match_id,
                 show_true_label=True, show_predicted_probability=True, show_momentum=True):
    """Plot probabilities, true labels and momentum for one match and render them in Streamlit.

    Args:
        specific_match_data: Rows of one match. It must have ``elapsed_time``
            (strictly increasing when momentum is shown), ``game_no`` and
            ``set_no``, plus ``point_victor`` when ``show_true_label`` is set.
        positive_class_probabilities: One probability per row of
            ``specific_match_data``.
        match_id: Shown in the plot title.
        show_true_label: Scatter ``point_victor - 1`` in red.
        show_predicted_probability: Draw the probabilities as a blue line.
        show_momentum: Draw the smoothed momentum curve in green.
    """
    # Use an explicit figure and always close it: Streamlit reruns the script
    # on every interaction, and unclosed pyplot figures accumulate.
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        elapsed_time = specific_match_data['elapsed_time']
        if show_predicted_probability:
            sns.lineplot(
                x=elapsed_time,
                y=positive_class_probabilities,
                marker='o',
                linestyle='-',
                color='blue',
                label='Predicted Probability',
                ax=ax,
            )
        if show_true_label:
            # Assumes point_victor is coded 1/2, so "- 1" puts it on the 0/1 axis -- confirm.
            sns.scatterplot(
                x=elapsed_time,
                y=(specific_match_data['point_victor'] - 1),
                color='red',
                label='True Label',
                s=60,
                ax=ax,
            )
        if show_momentum:
            _draw_momentum(ax, elapsed_time, _compute_momentum(positive_class_probabilities))
        _mark_set_and_game_changes(ax, specific_match_data)

        ax.set_title(f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}')
        ax.set_xlabel('Elapsed Time')
        ax.set_ylabel('Probability / True Label / Momentum')
        ax.grid(True)
        # Only build a legend when a labelled artist was drawn (avoids a warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        st.pyplot(fig)
    finally:
        plt.close(fig)
# Columns the app reads unconditionally from the uploaded CSV (here and in the
# prediction / plotting helpers). point_victor is only needed for "Show True Label".
_REQUIRED_COLUMNS = (
    'match_id', 'elapsed_time', 'server', 'set_no', 'game_no',
    'p1_facing_game_point', 'p2_facing_game_point',
    'p1_PAI', 'p2_PAI', 'normalized_rally',
)


def _select_sets_and_games(match_data):
    """Let the user optionally narrow ``match_data`` to chosen sets and games.

    Returns ``match_data`` unchanged when the "observe" checkbox is off or no
    set is selected. An empty game selection keeps every game of the
    selected sets.
    """
    if not st.checkbox("Observe specific set(s) and game(s)", False):
        return match_data

    unique_sets = match_data['set_no'].unique()
    selected_sets = st.multiselect('Select set(s) to observe', unique_sets, default=unique_sets)
    if not selected_sets:
        return match_data

    filtered = match_data[match_data['set_no'].isin(selected_sets)]
    unique_games = filtered['game_no'].unique()
    selected_games = st.multiselect('Select games to observe (default is all games)', unique_games, default=unique_games)
    if selected_games:
        filtered = filtered[filtered['game_no'].isin(selected_games)]
    return filtered


def main():
    """Render the Momentum Catcher page.

    Flow: upload a CSV -> pick a match_id -> derive the model features ->
    optionally narrow to chosen sets/games -> predict and plot.
    """
    st.title('Momentum Catcher')
    st.markdown("""
To get started, you can find sample data available for download at
[Hugging Face Spaces](https://huggingface.co/spaces/Nagi-ovo/Tennis-Momentum-Tracker/tree/main/data).
This data can be used directly in this application to analyze tennis match momentum.
""")
    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
    if uploaded_file is None:
        return

    new_data = pd.read_csv(uploaded_file)
    # Missing values are deliberately not dropped: a whole-frame dropna() would
    # discard any point with a NaN in *any* column, even ones the model never reads.

    missing = [column for column in _REQUIRED_COLUMNS if column not in new_data.columns]
    if missing:
        st.error(f"The uploaded CSV is missing required column(s): {', '.join(missing)}")
        return

    # Let the user pick one of the matches present in the upload.
    unique_match_ids = new_data['match_id'].unique()
    match_id_input = st.selectbox("Select the match_id you want to analyze", unique_match_ids)

    repo_id = "Nagi-ovo/Momentum-XGboost"
    filename = "xgboost.pkl"
    xgb_model = load_model(repo_id, filename)

    # Derived model inputs.
    features = ['PAI_diff', 'normalized_rally', 'is_game_point']
    new_data['is_game_point'] = (new_data['p1_facing_game_point'] - new_data['p2_facing_game_point']).abs()
    new_data['PAI_diff'] = new_data['p1_PAI'] - new_data['p2_PAI']

    # Only the cleaned match rows are needed here; probabilities are recomputed
    # below for whatever subset the user selects.
    specific_match_data, _ = preprocess_and_predict(xgb_model, new_data, match_id_input, features)
    if specific_match_data is None:
        return

    filtered_data = _select_sets_and_games(specific_match_data)

    show_true_label = st.checkbox("Show True Label", True)
    show_predicted_probability = st.checkbox("Show Predicted Probability", True)
    show_momentum = st.checkbox("Show Momentum", True)

    if filtered_data.empty:
        st.write("No data available for the selected set(s) and games. Please select again.")
        return

    filtered_positive_class_probabilities = xgb_model.predict_proba(filtered_data[features])[:, 1]
    plot_results(filtered_data, filtered_positive_class_probabilities, match_id_input,
                 show_true_label, show_predicted_probability, show_momentum)
# Script entry point: render the app only when this module is executed directly.
if __name__ == "__main__":
    main()
|