from huggingface_hub import hf_hub_download
import streamlit as st
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.interpolate import make_interp_spline
import numpy as np
import os

# Hugging Face access token read from the environment (needed if the model repo is private or gated)
hf_token = os.getenv("HF_TOKEN")

def load_model(repo_id, filename):
    # Download the model file from the Hugging Face Hub, authenticating with HF_TOKEN
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
    # Load the model with joblib
    model = joblib.load(model_path)
    return model

# Set the Seaborn plotting style
sns.set(style="whitegrid")

def preprocess_and_predict(model, data, match_id, features):
    # Keep only the rows for the requested match_id, drop duplicate timestamps,
    # and discard rows where server == 0
    specific_match_data = data[data['match_id'] == match_id]
    specific_match_data = specific_match_data.drop_duplicates(subset=['elapsed_time'], keep='first')
    specific_match_data = specific_match_data[specific_match_data['server'] != 0]

    if specific_match_data.empty:
        st.write(f"No data found for match_id {match_id}.")
        return None, None
    else:
        specific_match_data = specific_match_data.sort_values('elapsed_time')

        # Select the feature columns used by the model
        X_specific_match = specific_match_data[features]

        # Predict the probability of the positive class for each point
        positive_class_probabilities = model.predict_proba(X_specific_match)[:, 1]

        return specific_match_data, positive_class_probabilities

def plot_results(specific_match_data, positive_class_probabilities, match_id, show_true_label=True, show_predicted_probability=True, show_momentum=True):
    plt.figure(figsize=(14, 7))
    
    if show_predicted_probability:
        # Line plot of the predicted probabilities
        sns.lineplot(
            x=specific_match_data['elapsed_time'], 
            y=positive_class_probabilities, 
            marker='o', 
            linestyle='-', 
            color='blue', 
            label='Predicted Probability'
        )
    
    if show_true_label:
        # Scatter plot of the true labels (point_victor shifted to 0/1)
        sns.scatterplot(
            x=specific_match_data['elapsed_time'], 
            y=(specific_match_data['point_victor'] - 1), 
            color='red', 
            label='True Label', 
            s=60
        )

    if show_momentum:
        # Momentum: smooth the probabilities with a three-point average while
        # keeping the result the same length as the data
        adjusted_probabilities = []
        for i in range(1, len(positive_class_probabilities)):
            # Combine the previous and current probabilities with an assumed next value
            if i < len(positive_class_probabilities) - 1:
                next_prob = 0.5  # for every point except the last, assume the next probability is 0.5
            else:
                next_prob = positive_class_probabilities[i]  # for the last point, reuse its own probability
            adjusted_probabilities.append((positive_class_probabilities[i-1] + positive_class_probabilities[i] + next_prob) / 3)

        # The first point has no predecessor, so average its probability with 0.5 and insert it at the front
        adjusted_probabilities.insert(0, (positive_class_probabilities[0] + 0.5) / 2)

        # Keep the momentum series the same length as specific_match_data['elapsed_time']
        momentum = np.array(adjusted_probabilities[:len(specific_match_data)])
        
        X_smooth = np.linspace(specific_match_data['elapsed_time'].min(), specific_match_data['elapsed_time'].max(), 300)
        spline = make_interp_spline(specific_match_data['elapsed_time'], momentum, k=3)
        momentum_smooth = spline(X_smooth)
        
        plt.fill_between(X_smooth, momentum_smooth, color='green', alpha=0.3)
        plt.plot(X_smooth, momentum_smooth, color='green', label='Momentum')
    
    # Mark the moments where game_no or set_no changes
    for i in range(1, len(specific_match_data)):
        if specific_match_data['game_no'].iloc[i] != specific_match_data['game_no'].iloc[i-1]:
            plt.axvline(
                x=specific_match_data['elapsed_time'].iloc[i], 
                color='gray', 
                linestyle='--', 
                lw=2
            )
        if specific_match_data['set_no'].iloc[i] != specific_match_data['set_no'].iloc[i-1]:
            plt.axvline(
                x=specific_match_data['elapsed_time'].iloc[i], 
                color='red', 
                linestyle='-.', 
                lw=2
            )
    
    plt.title(f'Predicted Probability, Momentum, and True Label Over Time for Match {match_id}')
    plt.xlabel('Elapsed Time')
    plt.ylabel('Probability / True Label / Momentum')
    plt.grid(True)
    plt.legend()
    st.pyplot(plt)

def main():
    st.title('Momentum Catcher')

    st.markdown("""
        To get started, you can find sample data available for download at 
        [Hugging Face Spaces](https://huggingface.co/spaces/Nagi-ovo/Tennis-Momentum-Tracker/tree/main/data).
        This data can be used directly in this application to analyze tennis match momentum.
    """)
    
    uploaded_file = st.file_uploader("Upload your input CSV data", type="csv")
    # match_id_input = st.text_input("Enter the match_id you want to analyze", "2023-wimbledon-1301")

    if uploaded_file is not None:
        new_data = pd.read_csv(uploaded_file)
        new_data = new_data.dropna()
        
        # Collect every unique match_id present in the uploaded file
        unique_match_ids = new_data['match_id'].unique()
        
        # Let the user choose which match_id to analyze
        match_id_input = st.selectbox("Select the match_id you want to analyze", unique_match_ids)

        # Model repository and file name on the Hugging Face Hub
        repo_id = "Nagi-ovo/Momentum-XGboost"
        filename = "xgboost.pkl"
        
        # Load the XGBoost model and define the features it expects
        xgb_model = load_model(repo_id, filename)
        features = ['PAI_diff', 'normalized_rally', 'is_game_point']
        
        # Derive is_game_point and PAI_diff from the raw columns
        # (normalized_rally is expected to already be present in the CSV)
        new_data['is_game_point'] = abs(new_data['p1_facing_game_point'] - new_data['p2_facing_game_point'])
        new_data['PAI_diff'] = new_data['p1_PAI'] - new_data['p2_PAI']
        
        specific_match_data, positive_class_probabilities = preprocess_and_predict(xgb_model, new_data, match_id_input, features)
        
        if specific_match_data is not None:
            # Let the user choose whether to look at specific set(s) and game(s)
            observe_specific = st.checkbox("Observe specific set(s) and game(s)", False)
            
            if observe_specific:
                # The user can select one or more sets to observe
                unique_sets = specific_match_data['set_no'].unique()
                selected_sets = st.multiselect('Select set(s) to observe', unique_sets, default=unique_sets)
                
                # Within the selected set(s), choose one or more games
                if selected_sets:
                    filtered_data = specific_match_data[specific_match_data['set_no'].isin(selected_sets)]
                    unique_games = filtered_data['game_no'].unique()
                    selected_games = st.multiselect('Select games to observe (default is all games)', unique_games, default=unique_games)
                    
                    # Narrow the data down to the selected game_no values
                    if selected_games:
                        filtered_data = filtered_data[filtered_data['game_no'].isin(selected_games)]
                else:
                    filtered_data = specific_match_data  # if no set is selected, show all the data
            else:
                filtered_data = specific_match_data
                selected_games = specific_match_data['game_no'].unique()  # default to all games
            
            # Checkboxes controlling what is drawn
            show_true_label = st.checkbox("Show True Label", True)
            show_predicted_probability = st.checkbox("Show Predicted Probability", True)
            show_momentum = st.checkbox("Show Momentum", True)
            
            # Draw the chart only if the filtered data is not empty
            if not filtered_data.empty:
                # Recompute predicted probabilities on the filtered rows
                filtered_positive_class_probabilities = xgb_model.predict_proba(filtered_data[features])[:, 1]
                plot_results(filtered_data, filtered_positive_class_probabilities, match_id_input, show_true_label, show_predicted_probability, show_momentum)
            else:
                st.write("No data available for the selected set(s) and games. Please select again.")
            
if __name__ == '__main__':
    main()
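
# ---------------------------------------------------------------------------
# Running the app locally -- a sketch. The file name "app.py" and the exact
# dependency list are assumptions; adjust them to match your checkout.
# xgboost (and possibly scikit-learn) is included because joblib.load()
# needs the library that produced "xgboost.pkl" to be importable.
#
#   pip install streamlit pandas numpy scipy matplotlib seaborn joblib \
#       huggingface_hub xgboost scikit-learn
#   export HF_TOKEN=<your_hugging_face_token>   # only if the model repo is gated
#   streamlit run app.py
# ---------------------------------------------------------------------------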