Spaces:
Sleeping
Sleeping
File size: 5,704 Bytes
4bb4a0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# conda create -n IMH-XGBoost conda-forge::huggingface_hub
# pip install -r requirements.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
import os
# Fetch the trained model file from the Hugging Face Hub unless it is already
# present in the working directory.
if not os.path.exists('xgb.baseline.model.json'):
    from huggingface_hub import snapshot_download
    # The token is handed straight to the download call rather than going
    # through `login()`: with HF_TOKEN unset, `login(token=None)` falls back to
    # an interactive prompt, which cannot succeed on a headless server. Here a
    # `None` token just means "anonymous / use a locally cached credential".
    snapshot_download(
        repo_id='Limour-blog/IMH-XGBoost',
        local_dir=r'.',
        allow_patterns='xgb.baseline.model.json',
        token=os.environ.get("HF_TOKEN"),
    )
import xgboost as xgb
import numpy as np
# Module-level classifier used by `predict`: an XGBClassifier with categorical
# support enabled, whose trained state is restored from the local JSON file.
clf = xgb.XGBClassifier(enable_categorical=True)
clf.load_model("xgb.baseline.model.json")
def limit(_value, _min, _max):
    """Clamp *_value* so it is no smaller than *_min* and no larger than *_max*.

    The lower bound is applied first and the upper bound second, so when
    ``_min > _max`` the result is ``_max``.
    """
    # Equivalent to max(_value, _min): keep _value unless _min is strictly larger.
    floored = _min if _min > _value else _value
    # Equivalent to min(floored, _max): keep floored unless _max is strictly smaller.
    return _max if _max < floored else floored
def args2Array(
    BSA=1.824,
    CTNT=4.715,
    CK_MB=200.5,
    CRP=18.01,
    PD_DIMER=1.047,
    NT_PROBNP=883.6,
    ARRHYTHMIA=0,
    APOE=36.76,
    MHR=0.8378,
):
    """Clip and standardize the raw inputs into a single-row feature matrix.

    Each continuous feature is clamped to a fixed [min, max] window (via
    ``limit``) and then z-scored with a fixed mean and standard deviation.
    ARRHYTHMIA is passed through unscaled. Every parameter's default equals
    that feature's mean, so a defaulted feature standardizes to 0.

    ``None`` (e.g. a cleared numeric field in the UI) is treated as missing
    and replaced by the feature mean / parameter default instead of raising
    a TypeError inside ``limit``. A ``None`` ARRHYTHMIA becomes 0.

    Returns:
        numpy array of shape (1, 9) with columns in the order
        BSA, CTNT, CK_MB, CRP, PD_DIMER, NT_PROBNP, ARRHYTHMIA, APOE, MHR.
        NOTE(review): presumably this is the column order the model was
        trained with — keep it in sync with the model file.
    """
    def scale(value, lo, hi, mean, std):
        # Missing input -> the mean (which is also the parameter default),
        # matching the UI hint to keep the default when a value is unknown.
        if value is None:
            value = mean
        return (limit(value, lo, hi) - mean) / std

    # NOTE(review): the `_0` / `_1` / `_3` tags carried over from the original
    # comments are unexplained; they may index measurement time points.
    features = [
        #     value      clip_min  clip_max  mean    std
        scale(BSA,       1.401,    2.231,    1.824,  0.1654),
        scale(CTNT,      -9.566,   19.58,    4.715,  3.877),   # _0
        scale(CK_MB,     -213,     571,      200.5,  154.3),   # _0
        scale(CRP,       -25.04,   55.86,    18.01,  17.53),   # _1
        scale(PD_DIMER,  -1.131,   2.959,    1.047,  0.8045),
        scale(NT_PROBNP, -610.1,   2106,     883.6,  625.8),   # _3
        0 if ARRHYTHMIA is None else ARRHYTHMIA,
        scale(APOE,      3.625,    68.62,    36.76,  13.85),
        scale(MHR,       -0.06439, 1.683,    0.8378, 0.3103),
    ]
    return np.array([features])
def predict(_array):
    """Return the module-level classifier's class-1 probability for the first row of *_array*, as a Python float."""
    probabilities = clf.predict_proba(_array)
    first_row_positive = probabilities[0, 1]
    return float(first_row_positive)
# Startup sanity check: a known-positive and a known-negative case must land on
# the expected side of the 0.72 decision cutoff before the UI is built.
# Explicit `raise` is used instead of `assert`, which `python -O` strips out.
def _self_check():
    """Raise RuntimeError if the loaded model misclassifies either reference case."""
    # Verify the model predicts the positive reference case as positive.
    positive_proba = predict(args2Array(
        BSA=1.99,
        CTNT=10,  # _0
        CK_MB=374,  # _0
        CRP=14.4,  # _1
        PD_DIMER=0.88,
        NT_PROBNP=463.7,  # _3
        ARRHYTHMIA=0,
        APOE=37,
        MHR=0.8378,
    ))
    # `not ... >=` (rather than `<`) also rejects a NaN probability.
    if not positive_proba >= 0.72:
        raise RuntimeError(
            f"Model self-check failed: positive case scored {positive_proba:.4f}, expected >= 0.72"
        )

    # Verify the model predicts the negative reference case as negative.
    negative_proba = predict(args2Array(
        BSA=1.51,
        CTNT=1.53,  # _0
        CK_MB=95,  # _0
        CRP=4.9,  # _1
        PD_DIMER=1.4,
        NT_PROBNP=519.2,  # _3
        ARRHYTHMIA=0,
        APOE=36.76,
        MHR=0.5581,
    ))
    if not negative_proba < 0.72:
        raise RuntimeError(
            f"Model self-check failed: negative case scored {negative_proba:.4f}, expected < 0.72"
        )


_self_check()
import gradio as gr
# ========== Full-feature model UI ==========
with gr.Blocks() as complete_model:
    # Input row 1: body surface area, arrhythmia flag, post-PCI D-dimer peak.
    with gr.Row():
        g_BSA = gr.Number(
            value=1.824,
            label="BSA",
            info="患者的体表面积, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
        g_ARRHYTHMIA = gr.Checkbox(
            value=False,
            label="ARRHYTHMIA",
            info="患者是否发生恶性心律失常或传导阻滞, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
        g_PD_DIMER = gr.Number(
            value=1.047,
            label="PD_DIMER",
            info="PCI术后D-二聚体峰值, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
    # Input row 2: CTNT, CK_MB, NT_PROBNP.
    with gr.Row():
        g_CTNT = gr.Number(
            value=4.715,
            label="CTNT",
            info="PCI术后即刻的CTNT值, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
        g_CK_MB = gr.Number(
            value=200.5,
            label="CK_MB",
            info="PCI术后即刻的CK_MB值, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
        g_NT_PROBNP = gr.Number(
            value=883.6,
            label="NT_PROBNP",
            info="PCI术后36小时的NT_PROBNP值, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
    # Input row 3: CRP, APOE, MHR.
    with gr.Row():
        g_CRP = gr.Number(
            value=18.01,
            label="CRP",
            info="PCI术后24小时的CRP值, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
        g_APOE = gr.Number(
            value=36.76,
            label="APOE",
            info="患者血脂APOE值, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
        g_MHR = gr.Number(
            value=0.8378,
            label="MHR",
            info="单核细胞与高密度脂蛋白胆固醇比值, 缺失请保持默认值",
            interactive=True,
            scale=1,
        )
    # Output row: raw probability and the thresholded conclusion (read-only).
    with gr.Row():
        g_output1 = gr.Number(
            label="XGB.predict_proba",
            info="cutoff值为0.72",
            interactive=False,
            scale=1,
        )
        g_output2 = gr.Textbox(
            label="结论",
            info="预测患者IMH为阳性或阴性",
            interactive=False,
            scale=1,
        )
    g_calc = gr.Button("计算", size='lg', variant="primary")

    def btn_calc(
        BSA, CTNT, CK_MB,
        CRP, PD_DIMER, NT_PROBNP,
        ARRHYTHMIA, APOE, MHR
    ):
        """Click handler: score the form inputs and label the result.

        Returns the class-1 probability rounded to 4 decimals, plus the
        positive/negative conclusion string chosen with a 0.72 cutoff.
        """
        features = args2Array(
            BSA=BSA,
            CTNT=CTNT,  # _0
            CK_MB=CK_MB,  # _0
            CRP=CRP,  # _1
            PD_DIMER=PD_DIMER,
            NT_PROBNP=NT_PROBNP,  # _3
            # The checkbox value is truthy/falsy; the model feature is 0/1.
            ARRHYTHMIA=int(bool(ARRHYTHMIA)),
            APOE=APOE,
            MHR=MHR,
        )
        probability = predict(features)
        verdict = '阳性' if probability >= 0.72 else '阴性'
        return round(probability, 4), verdict

    # Wire the button: the input order must match btn_calc's parameter order.
    g_calc.click(
        fn=btn_calc,
        inputs=[
            g_BSA, g_CTNT, g_CK_MB,
            g_CRP, g_PD_DIMER, g_NT_PROBNP,
            g_ARRHYTHMIA, g_APOE, g_MHR,
        ],
        outputs=[g_output1, g_output2],
    )
# ========== Launch the app ==========
# Present the single Blocks app as a one-tab interface.
demo = gr.TabbedInterface(
    [complete_model],
    ["complete_model"],
)
# Close any Gradio apps already running in this process before launching.
gr.close_all()
# Queue holds at most one pending request; API exposure is turned off via
# api_open=False / show_api=False. Listens on all network interfaces.
demo.queue(max_size=1, api_open=False).launch(
    show_api=False,
    show_error=True,
    share=False,
    server_name="0.0.0.0",
)
|