abreza commited on
Commit
14aebdf
·
1 Parent(s): 5d00885
Files changed (7) hide show
  1. .gitignore +3 -0
  2. app.py +16 -67
  3. config.py +58 -0
  4. interface.py +145 -0
  5. model_utils.py +79 -0
  6. requirements.txt +5 -1
  7. styles.py +80 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv
2
+ **/__pycache__
3
+ .gradio
app.py CHANGED
@@ -1,71 +1,20 @@
1
- import pickle
2
- import gradio as gr
3
- import shap
4
- import pandas as pd
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import io
8
- import base64
9
 
10
- # Load model
11
- with open("best_model.pkl", "rb") as f:
12
- model = pickle.load(f)
13
-
14
- # Example features - replace with actual features your model uses
15
- feature_names = ["Age",
16
- "weight",
17
- "height",
18
- "BMI",
19
- "gravidity",
20
- "parity",
21
- "H.Abortion",
22
- "living.Child",
23
- "Gestational.Age",
24
- "Hemoglobin",
25
- "hematocrit",
26
- "platelet",
27
- "MPV.mean.platelet.volume",
28
- "PDW.platelet.distribution.width",
29
- "neutrophil",
30
- "lymphocyte",
31
- "NLR.neutrophil.to.lymphocyte",
32
- "PLR.platelet.to.lymphocyte.ratio"
33
- ]
34
-
35
-
36
-
37
- def predict_and_explain(*inputs):
38
- # Create a DataFrame for the input
39
- input_data = pd.DataFrame([inputs], columns=feature_names)
40
 
41
- # Prediction
42
- prediction = model.predict(input_data)[0]
43
-
44
- # SHAP explanation (on the fly)
45
- # explainer = shap.Explainer(model)
46
- # shap_values = explainer(input_data)
47
- explainer = shap.KernelExplainer(model)
48
- shap_values = explainer.shap_values(input_data, nsamples=100)
49
-
50
- # SHAP plot
51
- shap_html = shap.plots.force(shap_values[0], matplotlib=False)
52
- return f"Risk Prediction: {'High' if prediction else 'Low'}", shap_html
53
-
54
-
55
- # Build the Gradio interface
56
- input_components = [gr.Number(label=feat) for feat in feature_names]
57
- output_components = [
58
- gr.Textbox(label="Prediction"),
59
- gr.HTML(label="Feature Importance")
60
- ]
61
-
62
- demo = gr.Interface(
63
- fn=predict_and_explain,
64
- inputs=input_components,
65
- outputs=output_components,
66
- title="Pregnancy Risk Analyzer",
67
- description="Enter patient data to analyze pregnancy risk and see important features using SHAP"
68
- )
69
 
70
  if __name__ == "__main__":
71
- demo.launch()
 
1
+ from interface import create_interface
2
+ from model_utils import load_model
 
 
 
 
 
 
3
 
4
+ def main():
5
+ print("Loading model...")
6
+ model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ if model is None:
9
+ print("Warning: Model could not be loaded. Please ensure 'best_model.pkl' exists.")
10
+ else:
11
+ print("Model loaded successfully!")
12
+
13
+ print("Creating interface...")
14
+ demo = create_interface()
15
+
16
+ print("Launching application...")
17
+ demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  if __name__ == "__main__":
20
+ main()
config.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL_PATH = 'best_model.pkl'
2
+
3
+ FEATURE_NAMES = [
4
+ 'Age', 'weight', 'height', 'BMI', 'gravidity', 'parity', 'H.Abortion',
5
+ 'living.Child', 'Gestational.Age', 'Hemoglobin', 'hematocrit', 'platelet',
6
+ 'MPV.mean.platelet.volume', 'PDW.platelet.distribution.width',
7
+ 'neutrophil', 'lymphocyte', 'NLR.neutrophil.to.lymphocyte',
8
+ 'PLR.platelet.to.lymphocyte.ratio'
9
+ ]
10
+
11
+ FEATURE_NAMES_FA = [
12
+ 'سن', 'وزن', 'قد', 'BMI', 'تعداد بارداری', 'تعداد زایمان', 'تعداد سقط',
13
+ 'فرزند زنده', 'سن بارداری', 'هموگلوبین', 'هماتوکریت', 'پلاکت',
14
+ 'MPV', 'PDW', 'نوتروفیل', 'لنفوسیت', 'NLR', 'PLR'
15
+ ]
16
+
17
+ APP_TITLE = "🩺 سیستم پیش‌بینی سلامت جنین"
18
+ MODEL_ACCURACY = "95.8%"
19
+ MODEL_AUC = "99.3%"
20
+
21
+ DEFAULT_VALUES = {
22
+ 'age': None, 'weight': None, 'height': None, 'gravidity': None, 'parity': None,
23
+ 'h_abortion': None, 'living_child': None, 'gestational_age': None,
24
+ 'hemoglobin': None, 'hematocrit': None, 'platelet': None, 'mpv': None,
25
+ 'pdw': None, 'neutrophil': None, 'lymphocyte': None
26
+ }
27
+
28
+ FIELD_RANGES = {
29
+ 'age': {'min': 15, 'max': 60}, 'weight': {'min': 35, 'max': 150},
30
+ 'height': {'min': 130, 'max': 200}, 'gravidity': {'min': 0, 'max': 15},
31
+ 'parity': {'min': 0, 'max': 12}, 'h_abortion': {'min': 0, 'max': 10},
32
+ 'living_child': {'min': 0, 'max': 12}, 'gestational_age': {'min': 1, 'max': 44},
33
+ 'hemoglobin': {'min': 6.0, 'max': 20.0}, 'hematocrit': {'min': 20.0, 'max': 60.0},
34
+ 'platelet': {'min': 50, 'max': 1000}, 'mpv': {'min': 5.0, 'max': 20.0},
35
+ 'pdw': {'min': 8.0, 'max': 30.0}, 'neutrophil': {'min': 0.5, 'max': 15.0},
36
+ 'lymphocyte': {'min': 0.2, 'max': 8.0}
37
+ }
38
+
39
+ EXAMPLE_CASES = {
40
+ "مثال ۱: بیمار کم‌خطر": {
41
+ 'age': 28, 'weight': 68, 'height': 165, 'gravidity': 2, 'parity': 1,
42
+ 'h_abortion': 0, 'living_child': 1, 'gestational_age': 32,
43
+ 'hemoglobin': 12.5, 'hematocrit': 38.0, 'platelet': 280,
44
+ 'mpv': 8.5, 'pdw': 15.2, 'neutrophil': 4.2, 'lymphocyte': 2.1
45
+ },
46
+ "مثال ۲: بیمار پرخطر": {
47
+ 'age': 42, 'weight': 85, 'height': 158, 'gravidity': 5, 'parity': 3,
48
+ 'h_abortion': 1, 'living_child': 3, 'gestational_age': 28,
49
+ 'hemoglobin': 9.2, 'hematocrit': 28.5, 'platelet': 450,
50
+ 'mpv': 11.8, 'pdw': 18.7, 'neutrophil': 7.8, 'lymphocyte': 1.2
51
+ },
52
+ "مثال ۳: بیمار جوان": {
53
+ 'age': 22, 'weight': 58, 'height': 162, 'gravidity': 1, 'parity': 0,
54
+ 'h_abortion': 0, 'living_child': 0, 'gestational_age': 24,
55
+ 'hemoglobin': 11.8, 'hematocrit': 35.2, 'platelet': 220,
56
+ 'mpv': 9.2, 'pdw': 16.1, 'neutrophil': 3.8, 'lymphocyte': 2.5
57
+ }
58
+ }
interface.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from model_utils import predict_outcome
3
+ from styles import RTL_CSS, HTML_HEAD
4
+ from config import (
5
+ APP_TITLE, MODEL_ACCURACY, MODEL_AUC,
6
+ DEFAULT_VALUES, FIELD_RANGES, EXAMPLE_CASES
7
+ )
8
+
9
+ def create_patient_info_section():
10
+ with gr.Column():
11
+ gr.Markdown("### 📝 اطلاعات بیمار")
12
+
13
+ age = gr.Number(label="سن", value=DEFAULT_VALUES['age'],
14
+ minimum=FIELD_RANGES['age']['min'], maximum=FIELD_RANGES['age']['max'])
15
+ weight = gr.Number(label="وزن (کیلوگرم)", value=DEFAULT_VALUES['weight'],
16
+ minimum=FIELD_RANGES['weight']['min'], maximum=FIELD_RANGES['weight']['max'])
17
+ height = gr.Number(label="قد (سانتی‌متر)", value=DEFAULT_VALUES['height'],
18
+ minimum=FIELD_RANGES['height']['min'], maximum=FIELD_RANGES['height']['max'])
19
+
20
+ with gr.Row():
21
+ gravidity = gr.Number(label="تعداد بارداری", value=DEFAULT_VALUES['gravidity'],
22
+ minimum=FIELD_RANGES['gravidity']['min'], maximum=FIELD_RANGES['gravidity']['max'])
23
+ parity = gr.Number(label="تعداد زایمان", value=DEFAULT_VALUES['parity'],
24
+ minimum=FIELD_RANGES['parity']['min'], maximum=FIELD_RANGES['parity']['max'])
25
+
26
+ with gr.Row():
27
+ h_abortion = gr.Number(label="تعداد سقط", value=DEFAULT_VALUES['h_abortion'],
28
+ minimum=FIELD_RANGES['h_abortion']['min'], maximum=FIELD_RANGES['h_abortion']['max'])
29
+ living_child = gr.Number(label="فرزند زنده", value=DEFAULT_VALUES['living_child'],
30
+ minimum=FIELD_RANGES['living_child']['min'], maximum=FIELD_RANGES['living_child']['max'])
31
+
32
+ gestational_age = gr.Number(label="سن بارداری (هفته)", value=DEFAULT_VALUES['gestational_age'],
33
+ minimum=FIELD_RANGES['gestational_age']['min'], maximum=FIELD_RANGES['gestational_age']['max'])
34
+
35
+ return age, weight, height, gravidity, parity, h_abortion, living_child, gestational_age
36
+
37
+ def create_lab_tests_section():
38
+ with gr.Column():
39
+ gr.Markdown("### 🧪 آزمایشات خون")
40
+
41
+ hemoglobin = gr.Number(label="هموگلوبین", value=DEFAULT_VALUES['hemoglobin'],
42
+ minimum=FIELD_RANGES['hemoglobin']['min'], maximum=FIELD_RANGES['hemoglobin']['max'])
43
+ hematocrit = gr.Number(label="هماتوکریت", value=DEFAULT_VALUES['hematocrit'],
44
+ minimum=FIELD_RANGES['hematocrit']['min'], maximum=FIELD_RANGES['hematocrit']['max'])
45
+ platelet = gr.Number(label="پلاکت", value=DEFAULT_VALUES['platelet'],
46
+ minimum=FIELD_RANGES['platelet']['min'], maximum=FIELD_RANGES['platelet']['max'])
47
+
48
+ with gr.Row():
49
+ mpv = gr.Number(label="MPV", value=DEFAULT_VALUES['mpv'],
50
+ minimum=FIELD_RANGES['mpv']['min'], maximum=FIELD_RANGES['mpv']['max'])
51
+ pdw = gr.Number(label="PDW", value=DEFAULT_VALUES['pdw'],
52
+ minimum=FIELD_RANGES['pdw']['min'], maximum=FIELD_RANGES['pdw']['max'])
53
+
54
+ with gr.Row():
55
+ neutrophil = gr.Number(label="نوتروفیل", value=DEFAULT_VALUES['neutrophil'],
56
+ minimum=FIELD_RANGES['neutrophil']['min'], maximum=FIELD_RANGES['neutrophil']['max'])
57
+ lymphocyte = gr.Number(label="لنفوسیت", value=DEFAULT_VALUES['lymphocyte'],
58
+ minimum=FIELD_RANGES['lymphocyte']['min'], maximum=FIELD_RANGES['lymphocyte']['max'])
59
+
60
+ return hemoglobin, hematocrit, platelet, mpv, pdw, neutrophil, lymphocyte
61
+
62
+ def predict_with_explanation(age, weight, height, gravidity, parity, h_abortion,
63
+ living_child, gestational_age, hemoglobin, hematocrit,
64
+ platelet, mpv, pdw, neutrophil, lymphocyte):
65
+
66
+ required_fields = [age, weight, height, gravidity, parity, h_abortion,
67
+ living_child, gestational_age, hemoglobin, hematocrit,
68
+ platelet, mpv, pdw, neutrophil, lymphocyte]
69
+
70
+ if any(field is None or field == "" for field in required_fields):
71
+ return "⚠️ لطفاً تمام فیلدها را پر کنید", "برای پیش‌بینی دقیق، تمام اطلاعات مورد نیاز است.", None
72
+
73
+ result, detailed_report = predict_outcome(
74
+ age, weight, height, gravidity, parity, h_abortion,
75
+ living_child, gestational_age, hemoglobin, hematocrit,
76
+ platelet, mpv, pdw, neutrophil, lymphocyte
77
+ )
78
+
79
+ return result, detailed_report
80
+
81
+ def clear_all_fields():
82
+ return tuple([None] * 17)
83
+
84
+ def load_example(example_name):
85
+ example_data = EXAMPLE_CASES[example_name]
86
+ return tuple(example_data[key] for key in [
87
+ 'age', 'weight', 'height', 'gravidity', 'parity', 'h_abortion',
88
+ 'living_child', 'gestational_age', 'hemoglobin', 'hematocrit',
89
+ 'platelet', 'mpv', 'pdw', 'neutrophil', 'lymphocyte'
90
+ ])
91
+
92
+ def create_interface():
93
+ with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft(), css=RTL_CSS, head=HTML_HEAD) as demo:
94
+
95
+ gr.Markdown(f"""
96
+ # {APP_TITLE}
97
+
98
+ این سیستم با استفاده از مدل هوش مصنوعی **AdaBoost**، احتمال بروز عوارض در بارداری را پیش‌بینی می‌کند.
99
+
100
+ **📊 عملکرد مدل:** دقت {MODEL_ACCURACY} | AUC {MODEL_AUC}
101
+
102
+ 🔍 **ویژگی‌های سیستم:**
103
+ - پیش‌بینی دقیق با استفاده از هوش مصنوعی
104
+ - تحلیل SHAP برای توضیح تأثیر هر ویژگی
105
+ - گزارش تفصیلی و قابل فهم برای پزشکان
106
+
107
+ 📝 **راهنما:** تمام فیلدها را پر کنید یا از مثال‌های آماده استفاده کنید.
108
+ """)
109
+
110
+ with gr.Row():
111
+ patient_inputs = create_patient_info_section()
112
+ lab_inputs = create_lab_tests_section()
113
+
114
+ with gr.Row():
115
+ predict_btn = gr.Button("🔍 پیش‌بینی", variant="primary", size="lg")
116
+ clear_btn = gr.Button("🗑️ پاک کردن", variant="secondary")
117
+
118
+ with gr.Row():
119
+ with gr.Column(scale=2):
120
+ result_text = gr.Textbox(label="نتیجه پیش‌بینی", lines=2)
121
+ detailed_report = gr.Markdown(label="گزارش تفصیلی")
122
+
123
+ gr.Markdown("---")
124
+ gr.Markdown("## 📚 مثال‌های آماده")
125
+
126
+ with gr.Row():
127
+ for example_name in EXAMPLE_CASES.keys():
128
+ example_btn = gr.Button(f"📋 {example_name}", variant="secondary")
129
+ example_btn.click(
130
+ fn=lambda name=example_name: load_example(name),
131
+ outputs=list(patient_inputs) + list(lab_inputs)
132
+ )
133
+
134
+ predict_btn.click(
135
+ fn=predict_with_explanation,
136
+ inputs=list(patient_inputs) + list(lab_inputs),
137
+ outputs=[result_text, detailed_report]
138
+ )
139
+
140
+ clear_btn.click(
141
+ fn=clear_all_fields,
142
+ outputs=list(patient_inputs) + list(lab_inputs) + [result_text, detailed_report]
143
+ )
144
+
145
+ return demo
model_utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import joblib
3
+ import warnings
4
+ from config import MODEL_PATH
5
+
6
+ warnings.filterwarnings('ignore')
7
+
8
+ model = None
9
+
10
+ def load_model():
11
+ global model
12
+ try:
13
+ model = joblib.load(MODEL_PATH)
14
+ return model
15
+ except Exception as e:
16
+ print(f"Error loading model: {e}")
17
+ return None
18
+
19
+ def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
20
+ height_m = height / 100
21
+ bmi = weight / (height_m ** 2)
22
+ nlr = neutrophil / lymphocyte if lymphocyte > 0 else 0
23
+ plr = platelet / lymphocyte if lymphocyte > 0 else 0
24
+ return bmi, nlr, plr
25
+
26
+ def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
27
+ living_child, gestational_age, hemoglobin, hematocrit,
28
+ platelet, mpv, pdw, neutrophil, lymphocyte):
29
+ global model
30
+
31
+ if model is None:
32
+ model = load_model()
33
+ if model is None:
34
+ return "خطا: مدل بارگذاری نشد", ""
35
+
36
+ try:
37
+ bmi, nlr, plr = calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet)
38
+
39
+ input_data = np.array([[
40
+ age, weight, height, bmi, gravidity, parity, h_abortion,
41
+ living_child, gestational_age, hemoglobin, hematocrit, platelet,
42
+ mpv, pdw, neutrophil, lymphocyte, nlr, plr
43
+ ]])
44
+
45
+ prediction_proba = model.predict_proba(input_data)[0]
46
+ prediction = model.predict(input_data)[0]
47
+
48
+ if prediction == 0:
49
+ result = f"🟢 پیش‌بینی: سالم (احتمال سالم بودن: {prediction_proba[0]*100:.1f}%)"
50
+ risk_level = "کم"
51
+ else:
52
+ result = f"🔴 پیش‌بینی: پرخطر (احتمال عوارض: {prediction_proba[1]*100:.1f}%)"
53
+ risk_level = "بالا"
54
+
55
+ detailed_report = f"""
56
+ 📊 **گزارش تفصیلی پیش‌بینی**
57
+
58
+ **نتیجه کلی:** {result}
59
+
60
+ **سطح ریسک:** {risk_level}
61
+
62
+ **ویژگی‌های محاسبه شده:**
63
+ - BMI: {bmi:.2f}
64
+ - NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}
65
+ - PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}
66
+
67
+ ⚠️ **توجه:** این پیش‌بینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.
68
+ """
69
+
70
+ return result, detailed_report
71
+
72
+ except Exception as e:
73
+ return f"خطا در پردازش: {str(e)}", ""
74
+
75
+ def get_model():
76
+ global model
77
+ if model is None:
78
+ model = load_model()
79
+ return model
requirements.txt CHANGED
@@ -1,6 +1,10 @@
1
  gradio
2
- scikit-learn
3
  pandas
4
  numpy
 
 
5
  shap
6
  matplotlib
 
 
 
 
1
  gradio
 
2
  pandas
3
  numpy
4
+ scikit-learn
5
+ joblib
6
  shap
7
  matplotlib
8
+ xgboost
9
+ catboost
10
+ lightgbm
styles.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ RTL_CSS = """
2
+ /* RTL Support for Persian/Arabic */
3
+ .gradio-container {
4
+ direction: rtl !important;
5
+ text-align: right !important;
6
+ }
7
+
8
+ /* Fix input fields alignment */
9
+ .gr-textbox, .gr-number, .gr-slider {
10
+ direction: rtl !important;
11
+ text-align: right !important;
12
+ }
13
+
14
+ /* Fix labels */
15
+ label {
16
+ direction: rtl !important;
17
+ text-align: right !important;
18
+ }
19
+
20
+ /* Fix buttons */
21
+ .gr-button {
22
+ direction: rtl !important;
23
+ }
24
+
25
+ /* Fix markdown content */
26
+ .gr-markdown {
27
+ direction: rtl !important;
28
+ text-align: right !important;
29
+ }
30
+
31
+ /* Fix specific input elements */
32
+ input[type="number"], input[type="text"], textarea {
33
+ direction: rtl !important;
34
+ text-align: right !important;
35
+ }
36
+
37
+ /* Fix column layouts */
38
+ .gr-column {
39
+ direction: rtl !important;
40
+ }
41
+
42
+ /* Fix row layouts */
43
+ .gr-row {
44
+ direction: rtl !important;
45
+ }
46
+
47
+ /* Fix slider component */
48
+ .gr-slider input {
49
+ direction: ltr !important;
50
+ }
51
+
52
+ /* Ensure proper spacing for Persian text */
53
+ body {
54
+ font-family: 'Tahoma', 'Arial', sans-serif !important;
55
+ direction: rtl !important;
56
+ }
57
+
58
+ /* Fix any remaining LTR elements */
59
+ * {
60
+ direction: inherit;
61
+ }
62
+
63
+ /* Special fixes for gradio components */
64
+ .wrap.svelte-1116kco {
65
+ direction: rtl !important;
66
+ }
67
+
68
+ .container.svelte-1116kco {
69
+ direction: rtl !important;
70
+ }
71
+ """
72
+
73
+ HTML_HEAD = """
74
+ <meta charset="UTF-8">
75
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
76
+ <style>
77
+ @import url('https://fonts.googleapis.com/css2?family=Vazir:wght@300;400;500;600&display=swap');
78
+ body { font-family: 'Vazir', 'Tahoma', Arial, sans-serif !important; }
79
+ </style>
80
+ """