Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,104 +2,45 @@ import gradio as gr
|
|
2 |
import tensorflow as tf
|
3 |
from transformers import TFAutoModel, AutoTokenizer
|
4 |
import numpy as np
|
5 |
-
import shap
|
6 |
-
from scipy.special import softmax
|
7 |
|
8 |
-
|
9 |
-
MODEL_NAME = "
|
10 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
11 |
-
model = TFAutoModel.from_pretrained(MODEL_NAME)
|
12 |
|
13 |
-
#
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
]
|
19 |
|
20 |
-
#
|
21 |
-
def
|
22 |
-
tokens = tokenizer(
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
def explain_prediction(text: str):
|
45 |
-
explainer = shap.Explainer(lambda x: predict_condition(x), tokenizer)
|
46 |
-
shap_values = explainer([text])
|
47 |
-
return shap.plots.text(shap_values, display=False)
|
48 |
-
|
49 |
-
# Symptom Tracker (Simple Implementation)
|
50 |
-
symptom_history = []
|
51 |
-
|
52 |
-
def log_symptom(symptom: str):
|
53 |
-
symptom_history.append(symptom)
|
54 |
-
return f"Logged: {symptom}. Total symptoms logged: {len(symptom_history)}"
|
55 |
-
|
56 |
-
def display_symptom_trends():
|
57 |
-
return "\n".join(symptom_history[-10:]) # Last 10 logged symptoms
|
58 |
-
|
59 |
-
# Gradio UI Design
|
60 |
-
css = """
|
61 |
-
textarea { background-color: transparent; border: 1px solid #6366f1; }
|
62 |
-
"""
|
63 |
-
with gr.Blocks(title="MedAI Compass", css=css, theme=gr.themes.Soft()) as app:
|
64 |
-
# Header
|
65 |
-
gr.HTML("<h1>MedAI Compass: Comprehensive Symptom and Health Guide</h1>")
|
66 |
-
|
67 |
-
# Section: Symptom Diagnosis
|
68 |
-
with gr.Row():
|
69 |
-
gr.Markdown("## Symptom Diagnosis")
|
70 |
-
input_description = gr.Textbox(label="Describe your symptom")
|
71 |
-
diagnose_btn = gr.Button("Diagnose")
|
72 |
-
diagnosis_output = gr.Label(label="Possible Conditions")
|
73 |
-
diagnose_btn.click(predict_condition, inputs=input_description, outputs=diagnosis_output)
|
74 |
-
|
75 |
-
# Section: SHAP Analysis
|
76 |
-
with gr.Row():
|
77 |
-
gr.Markdown("## Explain Predictions")
|
78 |
-
shap_text_input = gr.Textbox(label="Enter Symptom Description for Analysis")
|
79 |
-
shap_btn = gr.Button("Generate Explanation")
|
80 |
-
shap_output = gr.HTML()
|
81 |
-
shap_btn.click(explain_prediction, inputs=shap_text_input, outputs=shap_output)
|
82 |
-
|
83 |
-
# Section: Personalized Advice
|
84 |
-
with gr.Row():
|
85 |
-
gr.Markdown("## Personalized Health Advice")
|
86 |
-
condition_input = gr.Dropdown(choices=CONDITIONS, label="Select a Condition")
|
87 |
-
advice_output = gr.Textbox(label="Advice")
|
88 |
-
advice_btn = gr.Button("Get Advice")
|
89 |
-
advice_btn.click(get_lifestyle_advice, inputs=condition_input, outputs=advice_output)
|
90 |
-
|
91 |
-
# Section: Symptom Tracker
|
92 |
-
with gr.Row():
|
93 |
-
gr.Markdown("## Symptom Tracker")
|
94 |
-
tracker_input = gr.Textbox(label="Log a Symptom")
|
95 |
-
tracker_btn = gr.Button("Log Symptom")
|
96 |
-
tracker_output = gr.Textbox(label="Logged Symptoms")
|
97 |
-
tracker_btn.click(log_symptom, inputs=tracker_input, outputs=tracker_output)
|
98 |
-
tracker_display_btn = gr.Button("Display Trends")
|
99 |
-
tracker_trends_output = gr.Textbox(label="Symptom Trends")
|
100 |
-
tracker_display_btn.click(display_symptom_trends, outputs=tracker_trends_output)
|
101 |
-
|
102 |
-
# Footer
|
103 |
-
gr.HTML("<p>© 2024 MedAI Compass. All Rights Reserved.</p>")
|
104 |
-
|
105 |
-
app.launch()
|
|
|
2 |
import tensorflow as tf
|
3 |
from transformers import TFAutoModel, AutoTokenizer
|
4 |
import numpy as np
|
|
|
|
|
5 |
|
6 |
+
# Load pre-trained model
|
7 |
+
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment-latest"
|
8 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
9 |
|
10 |
+
# Ensure model loads with TensorFlow and compatibility fixes
|
11 |
+
model = tf.keras.models.load_model("model.h5", custom_objects={
|
12 |
+
"TFRobertaModel": TFAutoModel.from_pretrained(MODEL_NAME)
|
13 |
+
})
|
14 |
+
|
15 |
+
# Labels for predictions
|
16 |
+
LABELS = [
|
17 |
+
"Cardiologist", "Dermatologist", "ENT Specialist", "Gastroenterologist",
|
18 |
+
"General Physicians", "Neurologist", "Ophthalmologist",
|
19 |
+
"Orthopedist", "Psychiatrist", "Respirologist", "Rheumatologist",
|
20 |
+
"Surgeon"
|
21 |
]
|
22 |
|
23 |
+
# Preprocess input data
|
24 |
+
def preprocess_input(text):
|
25 |
+
tokens = tokenizer(text, max_length=128, truncation=True, padding="max_length", return_tensors="tf")
|
26 |
+
return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
|
27 |
+
|
28 |
+
# Predict from input text
|
29 |
+
def predict_specialist(text):
|
30 |
+
inputs = preprocess_input(text)
|
31 |
+
predictions = model.predict(inputs)
|
32 |
+
return {LABELS[i]: float(predictions[0][i]) for i in range(len(LABELS))}
|
33 |
+
|
34 |
+
# Gradio UI
|
35 |
+
def build_interface():
|
36 |
+
with gr.Blocks() as demo:
|
37 |
+
gr.Markdown("## Welcome to FlinShaHealth")
|
38 |
+
text_input = gr.Textbox(label="Describe your symptoms:")
|
39 |
+
output_label = gr.Label(label="Predicted Specialist")
|
40 |
+
submit_btn = gr.Button("Predict")
|
41 |
+
submit_btn.click(predict_specialist, inputs=text_input, outputs=output_label)
|
42 |
+
return demo
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
app = build_interface()
|
46 |
+
app.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|