mgbam commited on
Commit
a66396a
·
verified ·
1 Parent(s): 025950a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -96
app.py CHANGED
@@ -2,104 +2,45 @@ import gradio as gr
2
  import tensorflow as tf
3
  from transformers import TFAutoModel, AutoTokenizer
4
  import numpy as np
5
- import shap
6
- from scipy.special import softmax
7
 
8
- ## Model and Tokenizer Setup
9
- MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
- model = TFAutoModel.from_pretrained(MODEL_NAME)
12
 
13
- # Constants
14
- SEQ_LEN = 128
15
- CONDITIONS = [
16
- "Common Cold", "COVID-19", "Allergies", "Anxiety Disorder", "Skin Infection",
17
- "Heart Condition", "Digestive Issues", "Migraine", "Muscle Strain", "Arthritis"
 
 
 
 
 
 
18
  ]
19
 
20
- # Dynamic Condition Predictions
21
- def predict_condition(description: str):
22
- tokens = tokenizer(
23
- description, max_length=SEQ_LEN, truncation=True, padding="max_length", return_tensors="tf"
24
- )
25
- outputs = model(tokens).last_hidden_state[:, 0, :] # CLS token output
26
- scores = softmax(outputs.numpy())
27
- predictions = dict(zip(CONDITIONS, scores.flatten()))
28
- return predictions
29
-
30
- # Lifestyle Tips
31
- LIFESTYLE_TIPS = {
32
- "Common Cold": "Rest, stay hydrated, and use saline nasal sprays.",
33
- "COVID-19": "Quarantine, stay hydrated, and seek medical attention if symptoms worsen.",
34
- "Allergies": "Avoid allergens, take antihistamines, and use air purifiers.",
35
- "Anxiety Disorder": "Practice mindfulness, exercise, and seek therapy if needed.",
36
- "Skin Infection": "Keep the area clean, use topical creams, and consult a dermatologist.",
37
- # Add more conditions and tips...
38
- }
39
-
40
- def get_lifestyle_advice(condition: str):
41
- return LIFESTYLE_TIPS.get(condition, "Consult a healthcare professional for guidance.")
42
-
43
- # Interactive Health Visualization (SHAP)
44
- def explain_prediction(text: str):
45
- explainer = shap.Explainer(lambda x: predict_condition(x), tokenizer)
46
- shap_values = explainer([text])
47
- return shap.plots.text(shap_values, display=False)
48
-
49
- # Symptom Tracker (Simple Implementation)
50
- symptom_history = []
51
-
52
- def log_symptom(symptom: str):
53
- symptom_history.append(symptom)
54
- return f"Logged: {symptom}. Total symptoms logged: {len(symptom_history)}"
55
-
56
- def display_symptom_trends():
57
- return "\n".join(symptom_history[-10:]) # Last 10 logged symptoms
58
-
59
- # Gradio UI Design
60
- css = """
61
- textarea { background-color: transparent; border: 1px solid #6366f1; }
62
- """
63
- with gr.Blocks(title="MedAI Compass", css=css, theme=gr.themes.Soft()) as app:
64
- # Header
65
- gr.HTML("<h1>MedAI Compass: Comprehensive Symptom and Health Guide</h1>")
66
-
67
- # Section: Symptom Diagnosis
68
- with gr.Row():
69
- gr.Markdown("## Symptom Diagnosis")
70
- input_description = gr.Textbox(label="Describe your symptom")
71
- diagnose_btn = gr.Button("Diagnose")
72
- diagnosis_output = gr.Label(label="Possible Conditions")
73
- diagnose_btn.click(predict_condition, inputs=input_description, outputs=diagnosis_output)
74
-
75
- # Section: SHAP Analysis
76
- with gr.Row():
77
- gr.Markdown("## Explain Predictions")
78
- shap_text_input = gr.Textbox(label="Enter Symptom Description for Analysis")
79
- shap_btn = gr.Button("Generate Explanation")
80
- shap_output = gr.HTML()
81
- shap_btn.click(explain_prediction, inputs=shap_text_input, outputs=shap_output)
82
-
83
- # Section: Personalized Advice
84
- with gr.Row():
85
- gr.Markdown("## Personalized Health Advice")
86
- condition_input = gr.Dropdown(choices=CONDITIONS, label="Select a Condition")
87
- advice_output = gr.Textbox(label="Advice")
88
- advice_btn = gr.Button("Get Advice")
89
- advice_btn.click(get_lifestyle_advice, inputs=condition_input, outputs=advice_output)
90
-
91
- # Section: Symptom Tracker
92
- with gr.Row():
93
- gr.Markdown("## Symptom Tracker")
94
- tracker_input = gr.Textbox(label="Log a Symptom")
95
- tracker_btn = gr.Button("Log Symptom")
96
- tracker_output = gr.Textbox(label="Logged Symptoms")
97
- tracker_btn.click(log_symptom, inputs=tracker_input, outputs=tracker_output)
98
- tracker_display_btn = gr.Button("Display Trends")
99
- tracker_trends_output = gr.Textbox(label="Symptom Trends")
100
- tracker_display_btn.click(display_symptom_trends, outputs=tracker_trends_output)
101
-
102
- # Footer
103
- gr.HTML("<p>© 2024 MedAI Compass. All Rights Reserved.</p>")
104
-
105
- app.launch()
 
2
  import tensorflow as tf
3
  from transformers import TFAutoModel, AutoTokenizer
4
  import numpy as np
 
 
5
 
6
+ # Load pre-trained model
7
+ MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment-latest"
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
9
 
10
+ # Ensure model loads with TensorFlow and compatibility fixes
11
+ model = tf.keras.models.load_model("model.h5", custom_objects={
12
+ "TFRobertaModel": TFAutoModel.from_pretrained(MODEL_NAME)
13
+ })
14
+
15
+ # Labels for predictions
16
+ LABELS = [
17
+ "Cardiologist", "Dermatologist", "ENT Specialist", "Gastroenterologist",
18
+ "General Physicians", "Neurologist", "Ophthalmologist",
19
+ "Orthopedist", "Psychiatrist", "Respirologist", "Rheumatologist",
20
+ "Surgeon"
21
  ]
22
 
23
+ # Preprocess input data
24
+ def preprocess_input(text):
25
+ tokens = tokenizer(text, max_length=128, truncation=True, padding="max_length", return_tensors="tf")
26
+ return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
27
+
28
+ # Predict from input text
29
+ def predict_specialist(text):
30
+ inputs = preprocess_input(text)
31
+ predictions = model.predict(inputs)
32
+ return {LABELS[i]: float(predictions[0][i]) for i in range(len(LABELS))}
33
+
34
+ # Gradio UI
35
+ def build_interface():
36
+ with gr.Blocks() as demo:
37
+ gr.Markdown("## Welcome to FlinShaHealth")
38
+ text_input = gr.Textbox(label="Describe your symptoms:")
39
+ output_label = gr.Label(label="Predicted Specialist")
40
+ submit_btn = gr.Button("Predict")
41
+ submit_btn.click(predict_specialist, inputs=text_input, outputs=output_label)
42
+ return demo
43
+
44
+ if __name__ == "__main__":
45
+ app = build_interface()
46
+ app.launch()