Spaces:
Sleeping
Sleeping
deployed v1.0.0
Browse files- app.py +141 -37
- models/xgboost.pkl +0 -0
app.py
CHANGED
@@ -1,57 +1,161 @@
|
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
-
from huggingface_hub import hf_hub_download
|
4 |
import joblib
|
5 |
-
|
6 |
-
|
|
|
|
|
7 |
|
8 |
-
def
|
9 |
# 从Hugging Face Hub下载模型文件
|
10 |
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
11 |
-
# 加载模型
|
12 |
model = joblib.load(model_path)
|
13 |
return model
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
new_data = pd.read_csv(uploaded_file)
|
18 |
-
X_new = new_data[features]
|
19 |
-
return X_new, new_data['point_victor']
|
20 |
|
21 |
-
def
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
if uploaded_file is not None:
|
28 |
-
# 指定特征列
|
29 |
-
features = [
|
30 |
-
'elapsed_time', 'set_no', 'game_no', 'game_advantage', 'set_advantage',
|
31 |
-
'p1_cumulative_distance', 'p2_cumulative_distance',
|
32 |
-
'p1_max_continuous_scoring', 'p2_max_continuous_scoring',
|
33 |
-
'p1_total_errors', 'p2_total_errors'
|
34 |
-
]
|
35 |
-
# 指定Hugging Face Hub上模型的仓库ID和文件名
|
36 |
repo_id = "Nagi-ovo/Momentum-XGboost"
|
37 |
filename = "xgboost.pkl"
|
38 |
|
39 |
-
#
|
40 |
-
xgb_model =
|
|
|
41 |
|
42 |
-
#
|
43 |
-
|
|
|
|
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
if __name__ == '__main__':
|
57 |
main()
|
|
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
|
|
4 |
import joblib
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import seaborn as sns
|
7 |
+
from scipy.interpolate import make_interp_spline
|
8 |
+
import numpy as np
|
9 |
|
10 |
+
def load_model(repo_id, filename):
    """Fetch a serialized model from the Hugging Face Hub and deserialize it.

    Args:
        repo_id: Hub repository identifier forwarded to ``hf_hub_download``.
        filename: Name of the model file inside that repository.

    Returns:
        The object that ``joblib.load`` reconstructs from the downloaded file.
    """
    # Download the model file from the Hub; the result is handed to joblib
    # as the path to load from.
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code -- only point this at repositories you trust.
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    return joblib.load(local_path)
|
16 |
|
17 |
+
# 设置Seaborn的风格
|
18 |
+
sns.set(style="whitegrid")
|
|
|
|
|
|
|
19 |
|
20 |
+
def preprocess_and_predict(model, data, match_id, features):
    """Select one match from *data* and score each of its rows with *model*.

    Args:
        model: Fitted classifier exposing ``predict_proba``.
        data: DataFrame that must contain ``match_id``, ``elapsed_time`` and
            every column named in *features*.
        match_id: Value compared for equality against the ``match_id`` column.
        features: Column names passed to the model, in the order given.

    Returns:
        ``(match_rows, probabilities)`` where ``match_rows`` is the selected
        match sorted by ``elapsed_time`` and ``probabilities`` is column 1 of
        ``model.predict_proba`` for those rows.  ``(None, None)`` is returned,
        after a message has been written to the page, when a required column
        is missing or no row matches *match_id*.
    """
    # Report a malformed upload on the page instead of crashing with KeyError.
    required = ['match_id', 'elapsed_time', *features]
    missing = [str(column) for column in required if column not in data.columns]
    if missing:
        st.error(f"Missing required column(s): {', '.join(missing)}.")
        return None, None

    # Keep only the rows that belong to the requested match.
    match_rows = data[data['match_id'] == match_id]
    if match_rows.empty:
        st.write(f"No data found for match_id {match_id}.")
        return None, None

    # Chronological order is what the time-series plot downstream expects.
    match_rows = match_rows.sort_values('elapsed_time')

    # Column 1 of predict_proba is the probability of the positive class.
    probabilities = model.predict_proba(match_rows[features])[:, 1]
    return match_rows, probabilities
|
37 |
+
|
38 |
+
def _smooth_momentum(elapsed, probabilities, window_size=3, n_samples=300):
    """Return ``(x, y)`` samples of the smoothed momentum curve, or ``None``.

    Momentum is a centred moving average of *probabilities* over
    *window_size* points, interpolated with a B-spline across *elapsed*.
    ``None`` is returned when fewer than two distinct time values exist,
    because no curve can be drawn through a single point.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")
    probabilities = np.asarray(probabilities, dtype=float)
    if probabilities.size == 0:
        return None

    # Pad with edge values so the average is defined at both ends, then use
    # mode='valid' so that exactly one output per input is produced and
    # output i is centred on input i.  (mode='same' on the padded array would
    # re-introduce implicit zero padding and shift the series by one sample.)
    left = window_size // 2
    right = window_size - 1 - left
    padded = np.pad(probabilities, (left, right), mode='edge')
    kernel = np.ones(window_size) / window_size
    momentum = np.convolve(padded, kernel, mode='valid')

    # make_interp_spline needs strictly increasing x, so rows sharing the
    # same timestamp are collapsed to the mean of their momentum values.
    elapsed = np.asarray(elapsed, dtype=float)
    unique_x, inverse = np.unique(elapsed, return_inverse=True)
    if unique_x.size < 2:
        return None
    inverse = np.ravel(inverse)
    mean_momentum = np.bincount(inverse, weights=momentum) / np.bincount(inverse)

    # k is the spline degree and needs at least k + 1 points; fall back to a
    # lower degree for very short selections instead of raising.
    degree = min(3, unique_x.size - 1)
    spline = make_interp_spline(unique_x, mean_momentum, k=degree)
    x_smooth = np.linspace(unique_x[0], unique_x[-1], n_samples)
    return x_smooth, spline(x_smooth)


def _change_points(values, elapsed):
    """Return the *elapsed* entries at which *values* differs from its predecessor."""
    values = np.asarray(values)
    changed = np.flatnonzero(values[1:] != values[:-1]) + 1
    return np.asarray(elapsed)[changed]


def plot_results(specific_match_data, positive_class_probabilities, match_id,
                 show_true_label=True, show_predicted_probability=True,
                 show_momentum=True, window_size=3):
    """Draw the prediction, true-label and momentum chart on the Streamlit page.

    Args:
        specific_match_data: DataFrame with ``elapsed_time``, ``point_victor``,
            ``game_no`` and ``set_no`` columns, one row per plotted point.
        positive_class_probabilities: One probability per row of
            *specific_match_data*, in the same order.
        match_id: Identifier shown in the chart title.
        show_true_label: Plot ``point_victor - 1`` as red markers.
        show_predicted_probability: Plot the raw probabilities as a blue line.
        show_momentum: Plot the smoothed moving average as a green area.
        window_size: Number of points in the momentum moving average.

    Returns:
        None.  The figure is handed to ``st.pyplot`` and then closed.
    """
    elapsed = specific_match_data['elapsed_time']
    # Use an explicit figure so it can be passed to Streamlit and released
    # afterwards; relying on pyplot's global figure leaks one per rerun.
    fig, ax = plt.subplots(figsize=(14, 7))

    if show_predicted_probability:
        # Line plot of the predicted probabilities.
        sns.lineplot(
            x=elapsed,
            y=positive_class_probabilities,
            marker='o',
            linestyle='-',
            color='blue',
            label='Predicted Probability',
            ax=ax,
        )

    if show_true_label:
        # Scatter plot of the true labels (point_victor shifted down by one).
        sns.scatterplot(
            x=elapsed,
            y=(specific_match_data['point_victor'] - 1),
            color='red',
            label='True Label',
            s=60,
            ax=ax,
        )

    if show_momentum:
        curve = _smooth_momentum(
            elapsed.to_numpy(), positive_class_probabilities, window_size
        )
        if curve is not None:
            x_smooth, momentum_smooth = curve
            # Semi-transparent area under the curve, plus its outline.
            ax.fill_between(x_smooth, momentum_smooth, color='green', alpha=0.3)
            ax.plot(x_smooth, momentum_smooth, color='green', label='Momentum')

    # Mark the moments where game_no (gray) and set_no (red) change.  Set
    # markers are drawn last so they stay on top where both coincide.
    boundary_styles = (('game_no', 'gray', '--'), ('set_no', 'red', '-.'))
    for column, color, linestyle in boundary_styles:
        boundaries = _change_points(
            specific_match_data[column].to_numpy(), elapsed.to_numpy()
        )
        for boundary in boundaries:
            ax.axvline(x=boundary, color=color, linestyle=linestyle, lw=2)

    ax.set_title(
        f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}'
    )
    ax.set_xlabel('Elapsed Time')
    ax.set_ylabel('Probability / True Label / Momentum')
    ax.grid(True)
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)
|
101 |
+
|
102 |
+
def main():
    """Streamlit entry point: upload a CSV, score one match and plot it.

    The page asks for a CSV file and a match id, loads the model from the
    Hugging Face Hub, derives the ``is_game_point`` and ``PAI_diff`` features,
    scores the requested match and plots it, optionally restricted to
    user-selected sets and games.
    """
    st.title('Momentum Catcher')

    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
    match_id_input = st.text_input(
        "Enter the match_id you want to analyze", "2023-wimbledon-1301"
    )

    if uploaded_file is None:
        return

    repo_id = "Nagi-ovo/Momentum-XGboost"
    filename = "xgboost.pkl"

    # Load the model.
    xgb_model = load_model(repo_id, filename)
    features = ['PAI_diff', 'normalized_rally', 'is_game_point']

    # Read the uploaded CSV file.
    new_data = pd.read_csv(uploaded_file)

    # The derived features below need these raw columns; report a malformed
    # upload on the page instead of crashing with a KeyError.
    source_columns = [
        'p1_facing_game_point', 'p2_facing_game_point', 'p1_PAI', 'p2_PAI',
    ]
    missing = [name for name in source_columns if name not in new_data.columns]
    if missing:
        st.error(f"Missing required column(s): {', '.join(missing)}.")
        return

    new_data['is_game_point'] = (
        new_data['p1_facing_game_point'] - new_data['p2_facing_game_point']
    ).abs()
    new_data['PAI_diff'] = new_data['p1_PAI'] - new_data['p2_PAI']

    specific_match_data, positive_class_probabilities = preprocess_and_predict(
        xgb_model, new_data, match_id_input, features
    )
    if specific_match_data is None:
        return

    # Row mask over the match; everything is shown unless the user narrows it.
    keep = np.ones(len(specific_match_data), dtype=bool)

    # Let the user choose whether to observe specific sets and games.
    if st.checkbox("Observe specific set(s) and game(s)", False):
        unique_sets = specific_match_data['set_no'].unique()
        selected_sets = st.multiselect(
            'Select set(s) to observe', unique_sets, default=unique_sets
        )

        # With no set selected the whole match stays visible.
        if selected_sets:
            in_sets = specific_match_data['set_no'].isin(selected_sets).to_numpy()
            unique_games = specific_match_data.loc[in_sets, 'game_no'].unique()
            selected_games = st.multiselect(
                'Select games to observe (default is all games)',
                unique_games,
                default=unique_games,
            )

            keep = in_sets
            # Narrow further to the selected games, if any were chosen.
            if selected_games:
                in_games = specific_match_data['game_no'].isin(selected_games)
                keep = in_sets & in_games.to_numpy()

    # Display-option checkboxes.
    show_true_label = st.checkbox("Show True Label", True)
    show_predicted_probability = st.checkbox("Show Predicted Probability", True)
    show_momentum = st.checkbox("Show Momentum", True)

    if not keep.any():
        st.write(
            "No data available for the selected set(s) and games. Please select again."
        )
        return

    # The probabilities are row-aligned with specific_match_data, so slicing
    # them with the same mask avoids running the model a second time.
    plot_results(
        specific_match_data.loc[keep],
        np.asarray(positive_class_probabilities)[keep],
        match_id_input,
        show_true_label,
        show_predicted_probability,
        show_momentum,
    )
|
159 |
+
|
160 |
if __name__ == '__main__':
|
161 |
main()
|
models/xgboost.pkl
DELETED
Binary file (271 kB)
|
|