Limour-blog commited on
Commit
4bb4a0c
·
verified ·
1 Parent(s): 1f952ea

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +153 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # conda create -n IMH-XGBoost conda-forge::huggingface_hub
2
+ # pip install -r requirements.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
3
+ import os
4
+ # 获取模型
5
+ if not os.path.exists('xgb.baseline.model.json'):
6
+ from huggingface_hub import login, snapshot_download
7
+ login(token=os.environ.get("HF_TOKEN"))
8
+ snapshot_download(repo_id='Limour-blog/IMH-XGBoost', local_dir=r'.', allow_patterns='xgb.baseline.model.json')
9
+
10
+ import xgboost as xgb
11
+ import numpy as np
12
+ clf = xgb.XGBClassifier(enable_categorical=True)
13
+ clf.load_model(r"xgb.baseline.model.json")
14
+
15
+ def limit(_value, _min, _max):
16
+ return min(max(_value, _min), _max)
17
+
18
+ def args2Array(
19
+ BSA=1.824,
20
+ CTNT=4.715, # _0
21
+ CK_MB=200.5, # _0
22
+ CRP=18.01, # _1
23
+ PD_DIMER=1.047,
24
+ NT_PROBNP=883.6, # _3
25
+ ARRHYTHMIA=0,
26
+ APOE=36.76,
27
+ MHR=0.8378
28
+ ):
29
+ BSA = limit(BSA, 1.401, 2.231)
30
+ BSA = (BSA - 1.824) / 0.1654
31
+ CTNT = limit(CTNT, -9.566, 19.58)
32
+ CTNT = (CTNT - 4.715) / 3.877
33
+ CK_MB = limit(CK_MB, -213, 571)
34
+ CK_MB = (CK_MB - 200.5) / 154.3
35
+ CRP = limit(CRP, -25.04, 55.86)
36
+ CRP = (CRP - 18.01) / 17.53
37
+ PD_DIMER = limit(PD_DIMER, -1.131, 2.959)
38
+ PD_DIMER = (PD_DIMER - 1.047) / 0.8045
39
+ NT_PROBNP = limit(NT_PROBNP, -610.1, 2106)
40
+ NT_PROBNP = (NT_PROBNP - 883.6) / 625.8
41
+ APOE = limit(APOE, 3.625, 68.62)
42
+ APOE = (APOE - 36.76) / 13.85
43
+ MHR = limit(MHR, -0.06439, 1.683)
44
+ MHR = (MHR - 0.8378) / 0.3103
45
+ return np.array([[BSA, CTNT, CK_MB,
46
+ CRP, PD_DIMER, NT_PROBNP,
47
+ ARRHYTHMIA, APOE, MHR]])
48
+
49
+ def predict(_array):
50
+ return float(clf.predict_proba(_array)[0,1])
51
+
52
+ # 测试模型预测阳性正确
53
+ assert predict(args2Array(
54
+ BSA=1.99,
55
+ CTNT=10, # _0
56
+ CK_MB=374, # _0
57
+ CRP=14.4, # _1
58
+ PD_DIMER=0.88,
59
+ NT_PROBNP=463.7, # _3
60
+ ARRHYTHMIA=0,
61
+ APOE=37,
62
+ MHR=0.8378
63
+ )) >= 0.72
64
+
65
+ # 测试模型预测阴性正确
66
+ assert predict(args2Array(
67
+ BSA=1.51,
68
+ CTNT=1.53, # _0
69
+ CK_MB=95, # _0
70
+ CRP=4.9, # _1
71
+ PD_DIMER=1.4,
72
+ NT_PROBNP=519.2, # _3
73
+ ARRHYTHMIA=0,
74
+ APOE=36.76,
75
+ MHR=0.5581
76
+ )) < 0.72
77
+
78
+ import gradio as gr
79
+
80
+ # ========== 完整版的模型 ==========
81
+ with gr.Blocks() as complete_model:
82
+ with gr.Row():
83
+ g_BSA = gr.Number(label="BSA", scale=1, value=1.824,
84
+ info="患者的体表面积, 缺失请保持默认值",
85
+ interactive=True)
86
+ g_ARRHYTHMIA = gr.Checkbox(label="ARRHYTHMIA", scale=1, value=False,
87
+ info="患者是否发生恶性心律失常或传导阻滞, 缺失请保持默认值",
88
+ interactive=True)
89
+ g_PD_DIMER = gr.Number(label="PD_DIMER", scale=1, value=1.047,
90
+ info="PCI术后D-二聚体峰值, 缺失请保持默认值",
91
+ interactive=True)
92
+ with gr.Row():
93
+ g_CTNT = gr.Number(label="CTNT", scale=1, value=4.715,
94
+ info="PCI术后即刻的CTNT值, 缺失请保持默认值",
95
+ interactive=True)
96
+ g_CK_MB = gr.Number(label="CK_MB", scale=1, value=200.5,
97
+ info="PCI术后即刻的CK_MB值, 缺失请保持默认值",
98
+ interactive=True)
99
+ g_NT_PROBNP = gr.Number(label="NT_PROBNP", scale=1, value=883.6,
100
+ info="PCI术后36小时的NT_PROBNP值, 缺失请保持默认值",
101
+ interactive=True)
102
+ with gr.Row():
103
+ g_CRP = gr.Number(label="CRP", scale=1, value=18.01,
104
+ info="PCI术后24小时的CRP值, 缺失请保持默认值",
105
+ interactive=True)
106
+ g_APOE = gr.Number(label="APOE", scale=1, value=36.76,
107
+ info="患者血脂APOE值, 缺失请保持默认值",
108
+ interactive=True)
109
+ g_MHR = gr.Number(label="MHR", scale=1, value=0.8378,
110
+ info="单核细胞与高密度脂蛋白胆固醇比值, 缺失请保持默认值",
111
+ interactive=True)
112
+ with gr.Row():
113
+ g_output1 = gr.Number(label="XGB.predict_proba", scale=1, interactive=False, info="cutoff值为0.72")
114
+ g_output2 = gr.Textbox(label="结论", scale=1, interactive=False, info="预测患者IMH为阳性或阴性")
115
+ g_calc = gr.Button("计算", variant="primary", size='lg')
116
+ def btn_calc(
117
+ BSA, CTNT, CK_MB,
118
+ CRP, PD_DIMER, NT_PROBNP,
119
+ ARRHYTHMIA, APOE, MHR
120
+ ):
121
+ res1 = predict(args2Array(
122
+ BSA=BSA,
123
+ CTNT=CTNT, # _0
124
+ CK_MB=CK_MB, # _0
125
+ CRP=CRP, # _1
126
+ PD_DIMER=PD_DIMER,
127
+ NT_PROBNP=NT_PROBNP, # _3
128
+ ARRHYTHMIA = (1 if ARRHYTHMIA else 0),
129
+ APOE=APOE,
130
+ MHR=MHR
131
+ ))
132
+ if res1 >= 0.72:
133
+ res2 = '阳性'
134
+ else:
135
+ res2 = '阴性'
136
+ return round(res1, 4), res2
137
+
138
+ g_calc.click(
139
+ fn = btn_calc,
140
+ inputs=[g_BSA, g_CTNT, g_CK_MB,
141
+ g_CRP, g_PD_DIMER, g_NT_PROBNP,
142
+ g_ARRHYTHMIA, g_APOE, g_MHR],
143
+ outputs=[g_output1, g_output2]
144
+ )
145
+
146
+ # ========== 开始运行 ==========
147
+ demo = gr.TabbedInterface([complete_model],
148
+ ["complete_model"])
149
+ gr.close_all()
150
+ demo.queue(api_open=False, max_size=1).launch(
151
+ server_name = "0.0.0.0",
152
+ share=False, show_error=True, show_api=False)
153
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ scikit-learn==1.5.2
2
+ xgboost-cpu==2.1.3
3
+ gradio==5.9.1