Nagi-ovo committed on
Commit e2947cc · 1 Parent(s): 2db0224

deployed v1.0.0

Files changed (2)
  1. app.py +141 -37
  2. models/xgboost.pkl +0 -0
app.py CHANGED
@@ -1,57 +1,161 @@
 
 import streamlit as st
 import pandas as pd
-from huggingface_hub import hf_hub_download
 import joblib
-from sklearn.preprocessing import LabelEncoder
-from sklearn.metrics import accuracy_score

-def load_model_from_hf_hub(repo_id, filename):
     # Download the model file from the Hugging Face Hub
     model_path = hf_hub_download(repo_id=repo_id, filename=filename)
-    # Load the model
     model = joblib.load(model_path)
     return model

-def preprocess_data(uploaded_file, features):
-    # Read and preprocess the new dataset
-    new_data = pd.read_csv(uploaded_file)
-    X_new = new_data[features]
-    return X_new, new_data['point_victor']

-def main():
-    st.title('ML Model Prediction Demo with XGBoost')

-    # File uploader
-    uploaded_file = st.file_uploader("Upload your input CSV file", type="csv")

     if uploaded_file is not None:
-        # Specify the feature columns
-        features = [
-            'elapsed_time', 'set_no', 'game_no', 'game_advantage', 'set_advantage',
-            'p1_cumulative_distance', 'p2_cumulative_distance',
-            'p1_max_continuous_scoring', 'p2_max_continuous_scoring',
-            'p1_total_errors', 'p2_total_errors'
-        ]
-        # Specify the repo ID and filename of the model on the Hugging Face Hub
         repo_id = "Nagi-ovo/Momentum-XGboost"
         filename = "xgboost.pkl"

-        # Load the model from the Hugging Face Hub
-        xgb_model = load_model_from_hf_hub(repo_id, filename)

-        # Preprocess the data
-        X_new, y_new = preprocess_data(uploaded_file, features)

-        # If the target variable also needs encoding, make sure it is transformed the same way as the training data
-        label_encoder = LabelEncoder()
-        y_new_encoded = label_encoder.fit_transform(y_new)
-
-        # Use the loaded model to make predictions
-        predictions = xgb_model.predict(X_new)
-
-        # Compute and display the accuracy
-        accuracy = accuracy_score(y_new_encoded, predictions)
-        st.write("Accuracy on new data:", accuracy)
-
 if __name__ == '__main__':
     main()
 
+from huggingface_hub import hf_hub_download
 import streamlit as st
 import pandas as pd
 import joblib
+import matplotlib.pyplot as plt
+import seaborn as sns
+from scipy.interpolate import make_interp_spline
+import numpy as np

+def load_model(repo_id, filename):
     # Download the model file from the Hugging Face Hub
     model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+    # Load the model with joblib
     model = joblib.load(model_path)
     return model

+# Set the Seaborn style
+sns.set(style="whitegrid")

+def preprocess_and_predict(model, data, match_id, features):
+    # Filter the data for the given match_id
+    specific_match_data = data[data['match_id'] == match_id]
+
+    if specific_match_data.empty:
+        st.write(f"No data found for match_id {match_id}.")
+        return None, None
+    else:
+        specific_match_data = specific_match_data.sort_values('elapsed_time')
+
+        # Preprocess the dataset
+        X_specific_match = specific_match_data[features]
+
+        # Predict class probabilities with the model
+        positive_class_probabilities = model.predict_proba(X_specific_match)[:, 1]

+        return specific_match_data, positive_class_probabilities
+
+def plot_results(specific_match_data, positive_class_probabilities, match_id, show_true_label=True, show_predicted_probability=True, show_momentum=True):
+    plt.figure(figsize=(14, 7))
+
+    if show_predicted_probability:
+        # Line plot of the predicted probabilities
+        sns.lineplot(
+            x=specific_match_data['elapsed_time'],
+            y=positive_class_probabilities,
+            marker='o',
+            linestyle='-',
+            color='blue',
+            label='Predicted Probability'
+        )

+    if show_true_label:
+        # Scatter plot of the true labels
+        sns.scatterplot(
+            x=specific_match_data['elapsed_time'],
+            y=(specific_match_data['point_victor'] - 1),
+            color='red',
+            label='True Label',
+            s=60
+        )
+
+    if show_momentum:
+        # Compute and plot the momentum
+        window_size = 3
+        padded_probabilities = np.pad(positive_class_probabilities, (window_size//2, window_size//2), 'edge')
+        momentum = np.convolve(padded_probabilities, np.ones(window_size)/window_size, mode='same')[:len(specific_match_data)]
+
+        # Build a smooth curve
+        X_smooth = np.linspace(specific_match_data['elapsed_time'].min(), specific_match_data['elapsed_time'].max(), 300)
+        spline = make_interp_spline(specific_match_data['elapsed_time'], momentum, k=3)  # k controls the smoothness of the curve
+        momentum_smooth = spline(X_smooth)
+
+        # Shade the area under the smoothed momentum curve (semi-transparent)
+        plt.fill_between(X_smooth, momentum_smooth, color='green', alpha=0.3)
+        # Also draw the outline of the smoothed curve if needed
+        plt.plot(X_smooth, momentum_smooth, color='green', label='Momentum')
+
+    # Mark the moments where set_no and game_no change
+    for i in range(1, len(specific_match_data)):
+        if specific_match_data['game_no'].iloc[i] != specific_match_data['game_no'].iloc[i-1]:
+            plt.axvline(
+                x=specific_match_data['elapsed_time'].iloc[i],
+                color='gray',
+                linestyle='--',
+                lw=2
+            )
+        if specific_match_data['set_no'].iloc[i] != specific_match_data['set_no'].iloc[i-1]:
+            plt.axvline(
+                x=specific_match_data['elapsed_time'].iloc[i],
+                color='red',
+                linestyle='-.',
+                lw=2
+            )
+
+    plt.title(f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}')
+    plt.xlabel('Elapsed Time')
+    plt.ylabel('Probability / True Label / Momentum')
+    plt.grid(True)
+    plt.legend()
+    st.pyplot(plt)
+
+def main():
+    st.title('Momentum Catcher')
+
+    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
+    match_id_input = st.text_input("Enter the match_id you want to analyze", "2023-wimbledon-1301")
+
     if uploaded_file is not None:
         repo_id = "Nagi-ovo/Momentum-XGboost"
         filename = "xgboost.pkl"

+        # Load the model
+        xgb_model = load_model(repo_id, filename)
+        features = ['PAI_diff', 'normalized_rally', 'is_game_point']

+        # Read the uploaded CSV file
+        new_data = pd.read_csv(uploaded_file)
+        new_data['is_game_point'] = abs(new_data['p1_facing_game_point'] - new_data['p2_facing_game_point'])
+        new_data['PAI_diff'] = new_data['p1_PAI'] - new_data['p2_PAI']

+        specific_match_data, positive_class_probabilities = preprocess_and_predict(xgb_model, new_data, match_id_input, features)
+
+        if specific_match_data is not None:
+            # Let the user choose whether to look at specific set(s) and game(s)
+            observe_specific = st.checkbox("Observe specific set(s) and game(s)", False)
+
+            if observe_specific:
+                # The user selects one or more sets to observe
+                unique_sets = specific_match_data['set_no'].unique()
+                selected_sets = st.multiselect('Select set(s) to observe', unique_sets, default=unique_sets)
+
+                # Based on the selected set_no, choose one or more games to observe
+                if selected_sets:
+                    filtered_data = specific_match_data[specific_match_data['set_no'].isin(selected_sets)]
+                    unique_games = filtered_data['game_no'].unique()
+                    selected_games = st.multiselect('Select games to observe (default is all games)', unique_games, default=unique_games)
+
+                    # Further filter the data to only the selected game_no
+                    if selected_games:
+                        filtered_data = filtered_data[filtered_data['game_no'].isin(selected_games)]
+                else:
+                    filtered_data = specific_match_data  # If no set is selected, show all the data
+            else:
+                filtered_data = specific_match_data
+                selected_games = specific_match_data['game_no'].unique()  # Select all games by default
+
+            # Display option checkboxes
+            show_true_label = st.checkbox("Show True Label", True)
+            show_predicted_probability = st.checkbox("Show Predicted Probability", True)
+            show_momentum = st.checkbox("Show Momentum", True)
+
+            # If the filtered data is not empty, draw the chart
+            if not filtered_data.empty:
+                # Compute the predicted probabilities for the filtered data
+                filtered_positive_class_probabilities = xgb_model.predict_proba(filtered_data[features])[:, 1]
+                plot_results(filtered_data, filtered_positive_class_probabilities, match_id_input, show_true_label, show_predicted_probability, show_momentum)
+            else:
+                st.write("No data available for the selected set(s) and games. Please select again.")
+
 if __name__ == '__main__':
     main()
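
A note on the momentum curve drawn by plot_results: it is a 3-point, edge-padded moving average of the predicted probabilities, re-interpolated with a cubic spline purely for display. Below is a minimal standalone sketch of that step; the elapsed_time and probability arrays are invented placeholder values, not model output.

import numpy as np
from scipy.interpolate import make_interp_spline

# Placeholder predicted probabilities over eight points (illustrative values only).
elapsed_time = np.array([0.0, 1.0, 2.5, 3.0, 4.2, 5.0, 6.1, 7.0])
probabilities = np.array([0.40, 0.55, 0.60, 0.45, 0.70, 0.65, 0.50, 0.60])

# 3-point moving average over an edge-padded copy, truncated to the original length,
# mirroring the window/padding logic in plot_results.
window_size = 3
padded = np.pad(probabilities, (window_size // 2, window_size // 2), 'edge')
momentum = np.convolve(padded, np.ones(window_size) / window_size, mode='same')[:len(probabilities)]

# Cubic-spline interpolation onto a dense grid gives the smooth curve that is shaded in the plot.
x_smooth = np.linspace(elapsed_time.min(), elapsed_time.max(), 300)
momentum_smooth = make_interp_spline(elapsed_time, momentum, k=3)(x_smooth)

print(momentum.round(3))
print(momentum_smooth[:5].round(3))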
models/xgboost.pkl DELETED
Binary file (271 kB)
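
For reference, the updated app.py reads the following columns from the uploaded CSV, either directly as model features or to derive PAI_diff and is_game_point and to filter and plot. The snippet below is a hypothetical one-row example of that schema; every value, and the sample_input.csv filename, is a placeholder for illustration, not real match data.

import pandas as pd

# Columns referenced by the new app.py; all values are placeholders.
sample = pd.DataFrame([{
    'match_id': '2023-wimbledon-1301',   # used to select the match to analyze
    'elapsed_time': 0.0,                 # x-axis of the plots, also used for sorting
    'set_no': 1,                         # set/game selectors and change markers
    'game_no': 1,
    'point_victor': 1,                   # plotted as the true label (point_victor - 1)
    'p1_PAI': 0.0,                       # PAI_diff = p1_PAI - p2_PAI
    'p2_PAI': 0.0,
    'p1_facing_game_point': 0,           # is_game_point = |p1 - p2|
    'p2_facing_game_point': 0,
    'normalized_rally': 0.0,             # used directly as a model feature
}])
sample.to_csv('sample_input.csv', index=False)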