Update app.py
Browse files
app.py
CHANGED
@@ -18,6 +18,7 @@ from sklearn.tree import DecisionTreeClassifier
|
|
18 |
from sklearn.ensemble import RandomForestClassifier
|
19 |
from sklearn.naive_bayes import GaussianNB
|
20 |
from sklearn.metrics import accuracy_score
|
|
|
21 |
|
22 |
# Suppress TensorFlow warnings
|
23 |
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
@@ -49,14 +50,17 @@ model_sentiment = AutoModelForSequenceClassification.from_pretrained("cardiffnlp
|
|
49 |
tokenizer_emotion = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
|
50 |
model_emotion = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
|
51 |
|
|
|
|
|
|
|
52 |
# Google Maps API Client
|
53 |
gmaps = googlemaps.Client(key=os.getenv("GOOGLE_API_KEY"))
|
54 |
|
55 |
# Load the disease dataset
|
56 |
df_train = pd.read_csv("Training.csv") # Change the file path as necessary
|
57 |
-
df_test = pd.read_csv("Testing.csv")
|
58 |
|
59 |
-
# Encode diseases
|
60 |
disease_dict = {
|
61 |
'Fungal infection': 0, 'Allergy': 1, 'GERD': 2, 'Chronic cholestasis': 3, 'Drug Reaction': 4,
|
62 |
'Peptic ulcer disease': 5, 'AIDS': 6, 'Diabetes ': 7, 'Gastroenteritis': 8, 'Bronchial Asthma': 9,
|
@@ -70,18 +74,22 @@ disease_dict = {
|
|
70 |
'Psoriasis': 39, 'Impetigo': 40
|
71 |
}
|
72 |
|
73 |
-
#
|
74 |
-
|
75 |
-
|
|
|
|
|
76 |
X = df.iloc[:, :-1] # Features
|
77 |
y = df.iloc[:, -1] # Target
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
81 |
|
82 |
-
# Preparing training and testing data
|
83 |
-
X_train, y_train
|
84 |
-
X_test, y_test
|
85 |
|
86 |
# Define the models
|
87 |
models = {
|
@@ -93,12 +101,13 @@ models = {
|
|
93 |
# Train and evaluate models
|
94 |
trained_models = {}
|
95 |
for model_name, model_obj in models.items():
|
96 |
-
model_obj.fit(X_train, y_train)
|
97 |
-
y_pred = model_obj.predict(X_test)
|
98 |
-
acc = accuracy_score(y_test, y_pred)
|
99 |
trained_models[model_name] = {'model': model_obj, 'accuracy': acc}
|
100 |
|
101 |
-
# Helper Functions for Chatbot
|
|
|
102 |
def bag_of_words(s, words):
|
103 |
"""Convert user input to bag-of-words vector."""
|
104 |
bag = [0] * len(words)
|
@@ -136,9 +145,8 @@ def analyze_sentiment(user_input):
|
|
136 |
return f"Sentiment: {sentiment_map[sentiment_class]}"
|
137 |
|
138 |
def detect_emotion(user_input):
|
139 |
-
"""Detect emotions based on input."""
|
140 |
-
|
141 |
-
result = pipe(user_input)
|
142 |
emotion = result[0]["label"].lower().strip()
|
143 |
emotion_map = {
|
144 |
"joy": "Joy 😊",
|
@@ -184,13 +192,15 @@ def generate_suggestions(emotion):
|
|
184 |
],
|
185 |
}
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
192 |
formatted_suggestions += [
|
193 |
-
f"| {title} | [{link}]({link}) |"
|
|
|
194 |
]
|
195 |
|
196 |
return "\n".join(formatted_suggestions)
|
@@ -200,7 +210,7 @@ def get_health_professionals_and_map(location, query):
|
|
200 |
try:
|
201 |
if not location or not query:
|
202 |
return [], "" # Return empty list if inputs are missing
|
203 |
-
|
204 |
geo_location = gmaps.geocode(location)
|
205 |
if geo_location:
|
206 |
lat, lng = geo_location[0]["geometry"]["location"].values()
|
@@ -214,9 +224,10 @@ def get_health_professionals_and_map(location, query):
|
|
214 |
popup=f"{place['name']}"
|
215 |
).add_to(map_)
|
216 |
return professionals, map_._repr_html_()
|
217 |
-
return [], ""
|
218 |
except Exception as e:
|
219 |
-
|
|
|
220 |
|
221 |
# Main Application Logic for Chatbot
|
222 |
def app_function_chatbot(user_input, location, query, history):
|
@@ -242,13 +253,15 @@ def predict_disease(symptoms):
|
|
242 |
predictions = {}
|
243 |
for model_name, info in trained_models.items():
|
244 |
prediction = info['model'].predict([input_test])[0]
|
245 |
-
predicted_disease =
|
246 |
predictions[model_name] = predicted_disease
|
247 |
|
248 |
# Create a Markdown table for displaying predictions
|
249 |
-
markdown_output = [
|
250 |
-
|
251 |
-
|
|
|
|
|
252 |
for model_name, disease in predictions.items():
|
253 |
markdown_output.append(f"| {model_name} | {disease} |")
|
254 |
|
@@ -275,15 +288,15 @@ welcome_message = """
|
|
275 |
margin: 20px 0;
|
276 |
}
|
277 |
.info-graphic img {
|
278 |
-
width: 150px;
|
279 |
-
height: auto;
|
280 |
-
margin: 0 10px;
|
281 |
}
|
282 |
h1 {
|
283 |
-
text-align: center;
|
284 |
-
font-size: 3em;
|
285 |
-
color: #004d40;
|
286 |
-
margin-bottom: 20px;
|
287 |
}
|
288 |
</style>
|
289 |
<div id="welcome-message">Welcome to the Well-Being Companion!</div>
|
@@ -325,27 +338,25 @@ with gr.Blocks(theme="shivi/calm_seafoam") as app:
|
|
325 |
location = gr.Textbox(label="Please Enter Your Current Location", placeholder="E.g., Honolulu", max_lines=1)
|
326 |
query = gr.Textbox(label="Search Health Professionals Nearby", placeholder="E.g., Health Professionals", max_lines=1)
|
327 |
|
328 |
-
with gr.Row():
|
329 |
submit_chatbot = gr.Button(value="Submit Your Message", variant="primary")
|
330 |
-
clear_chatbot = gr.Button(value="Clear", variant="secondary")
|
331 |
|
332 |
chatbot = gr.Chatbot(label="Chat History", show_label=True)
|
333 |
sentiment = gr.Textbox(label="Detected Sentiment", show_label=True)
|
334 |
emotion = gr.Textbox(label="Detected Emotion", show_label=True)
|
335 |
|
336 |
-
# Apply styles and create the DataFrame
|
337 |
professionals = gr.DataFrame(
|
338 |
-
label="Nearby Health Professionals",
|
339 |
headers=["Name", "Address"],
|
340 |
-
value=[]
|
341 |
)
|
342 |
|
343 |
suggestions_markdown = gr.Markdown(label="Suggestions")
|
344 |
map_html = gr.HTML(label="Interactive Map")
|
345 |
|
346 |
-
# Functionality to clear the chat input
|
347 |
def clear_input():
|
348 |
-
return "", []
|
349 |
|
350 |
submit_chatbot.click(
|
351 |
app_function_chatbot,
|
@@ -356,7 +367,7 @@ with gr.Blocks(theme="shivi/calm_seafoam") as app:
|
|
356 |
clear_chatbot.click(
|
357 |
clear_input,
|
358 |
inputs=None,
|
359 |
-
outputs=[user_input, chatbot]
|
360 |
)
|
361 |
|
362 |
with gr.Tab("Disease Prediction"):
|
@@ -385,11 +396,11 @@ with gr.Blocks(theme="shivi/calm_seafoam") as app:
|
|
385 |
disease_prediction_result = gr.Markdown(label="Predicted Diseases")
|
386 |
|
387 |
submit_disease.click(
|
388 |
-
lambda
|
389 |
-
[symptom1, symptom2, symptom3, symptom4, symptom5]),
|
390 |
inputs=[symptom1, symptom2, symptom3, symptom4, symptom5],
|
391 |
outputs=disease_prediction_result
|
392 |
)
|
393 |
|
394 |
# Launch the Gradio application
|
395 |
-
|
|
|
|
18 |
from sklearn.ensemble import RandomForestClassifier
|
19 |
from sklearn.naive_bayes import GaussianNB
|
20 |
from sklearn.metrics import accuracy_score
|
21 |
+
import logging
|
22 |
|
23 |
# Suppress TensorFlow warnings
|
24 |
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
|
|
50 |
tokenizer_emotion = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
|
51 |
model_emotion = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
|
52 |
|
53 |
+
# Initialize emotion pipeline once
|
54 |
+
emotion_pipeline = pipeline("text-classification", model=model_emotion, tokenizer=tokenizer_emotion)
|
55 |
+
|
56 |
# Google Maps API Client
|
57 |
gmaps = googlemaps.Client(key=os.getenv("GOOGLE_API_KEY"))
|
58 |
|
59 |
# Load the disease dataset
|
60 |
df_train = pd.read_csv("Training.csv") # Change the file path as necessary
|
61 |
+
df_test = pd.read_csv("Testing.csv") # Change the file path as necessary
|
62 |
|
63 |
+
# Encode diseases dictionary (optional, currently unused directly)
|
64 |
disease_dict = {
|
65 |
'Fungal infection': 0, 'Allergy': 1, 'GERD': 2, 'Chronic cholestasis': 3, 'Drug Reaction': 4,
|
66 |
'Peptic ulcer disease': 5, 'AIDS': 6, 'Diabetes ': 7, 'Gastroenteritis': 8, 'Bronchial Asthma': 9,
|
|
|
74 |
'Psoriasis': 39, 'Impetigo': 40
|
75 |
}
|
76 |
|
77 |
+
# Label encoder for consistent train/test encoding
|
78 |
+
label_encoder = LabelEncoder()
|
79 |
+
|
80 |
+
def prepare_data(df, is_train=True):
|
81 |
+
"""Prepares data for training/testing with consistent label encoding."""
|
82 |
X = df.iloc[:, :-1] # Features
|
83 |
y = df.iloc[:, -1] # Target
|
84 |
+
if is_train:
|
85 |
+
y_encoded = label_encoder.fit_transform(y)
|
86 |
+
else:
|
87 |
+
y_encoded = label_encoder.transform(y)
|
88 |
+
return X, y_encoded
|
89 |
|
90 |
+
# Preparing training and testing data with the same label encoder
|
91 |
+
X_train, y_train = prepare_data(df_train, is_train=True)
|
92 |
+
X_test, y_test = prepare_data(df_test, is_train=False)
|
93 |
|
94 |
# Define the models
|
95 |
models = {
|
|
|
101 |
# Train and evaluate models
|
102 |
trained_models = {}
|
103 |
for model_name, model_obj in models.items():
|
104 |
+
model_obj.fit(X_train, y_train)
|
105 |
+
y_pred = model_obj.predict(X_test)
|
106 |
+
acc = accuracy_score(y_test, y_pred)
|
107 |
trained_models[model_name] = {'model': model_obj, 'accuracy': acc}
|
108 |
|
109 |
+
# --- Helper Functions for Chatbot ---
|
110 |
+
|
111 |
def bag_of_words(s, words):
|
112 |
"""Convert user input to bag-of-words vector."""
|
113 |
bag = [0] * len(words)
|
|
|
145 |
return f"Sentiment: {sentiment_map[sentiment_class]}"
|
146 |
|
147 |
def detect_emotion(user_input):
|
148 |
+
"""Detect emotions based on input using cached pipeline."""
|
149 |
+
result = emotion_pipeline(user_input)
|
|
|
150 |
emotion = result[0]["label"].lower().strip()
|
151 |
emotion_map = {
|
152 |
"joy": "Joy 😊",
|
|
|
192 |
],
|
193 |
}
|
194 |
|
195 |
+
formatted_suggestions = [
|
196 |
+
"### Suggestions",
|
197 |
+
f"Since you’re feeling {emotion}, you might find these links particularly helpful. Don’t hesitate to explore:",
|
198 |
+
"| Title | Link |",
|
199 |
+
"|-------|------|"
|
200 |
+
]
|
201 |
formatted_suggestions += [
|
202 |
+
f"| {title} | [{link}]({link}) |"
|
203 |
+
for title, link in suggestions.get(emotion_key, [("No specific suggestions available.", "#")])
|
204 |
]
|
205 |
|
206 |
return "\n".join(formatted_suggestions)
|
|
|
210 |
try:
|
211 |
if not location or not query:
|
212 |
return [], "" # Return empty list if inputs are missing
|
213 |
+
|
214 |
geo_location = gmaps.geocode(location)
|
215 |
if geo_location:
|
216 |
lat, lng = geo_location[0]["geometry"]["location"].values()
|
|
|
224 |
popup=f"{place['name']}"
|
225 |
).add_to(map_)
|
226 |
return professionals, map_._repr_html_()
|
227 |
+
return [], ""
|
228 |
except Exception as e:
|
229 |
+
logging.error(f"Error fetching health professionals: {e}")
|
230 |
+
return [], ""
|
231 |
|
232 |
# Main Application Logic for Chatbot
|
233 |
def app_function_chatbot(user_input, location, query, history):
|
|
|
253 |
predictions = {}
|
254 |
for model_name, info in trained_models.items():
|
255 |
prediction = info['model'].predict([input_test])[0]
|
256 |
+
predicted_disease = label_encoder.inverse_transform([prediction])[0]
|
257 |
predictions[model_name] = predicted_disease
|
258 |
|
259 |
# Create a Markdown table for displaying predictions
|
260 |
+
markdown_output = [
|
261 |
+
"### Predicted Diseases",
|
262 |
+
"| Model | Predicted Disease |",
|
263 |
+
"|-------|------------------|"
|
264 |
+
]
|
265 |
for model_name, disease in predictions.items():
|
266 |
markdown_output.append(f"| {model_name} | {disease} |")
|
267 |
|
|
|
288 |
margin: 20px 0;
|
289 |
}
|
290 |
.info-graphic img {
|
291 |
+
width: 150px;
|
292 |
+
height: auto;
|
293 |
+
margin: 0 10px;
|
294 |
}
|
295 |
h1 {
|
296 |
+
text-align: center;
|
297 |
+
font-size: 3em;
|
298 |
+
color: #004d40;
|
299 |
+
margin-bottom: 20px;
|
300 |
}
|
301 |
</style>
|
302 |
<div id="welcome-message">Welcome to the Well-Being Companion!</div>
|
|
|
338 |
location = gr.Textbox(label="Please Enter Your Current Location", placeholder="E.g., Honolulu", max_lines=1)
|
339 |
query = gr.Textbox(label="Search Health Professionals Nearby", placeholder="E.g., Health Professionals", max_lines=1)
|
340 |
|
341 |
+
with gr.Row():
|
342 |
submit_chatbot = gr.Button(value="Submit Your Message", variant="primary")
|
343 |
+
clear_chatbot = gr.Button(value="Clear", variant="secondary")
|
344 |
|
345 |
chatbot = gr.Chatbot(label="Chat History", show_label=True)
|
346 |
sentiment = gr.Textbox(label="Detected Sentiment", show_label=True)
|
347 |
emotion = gr.Textbox(label="Detected Emotion", show_label=True)
|
348 |
|
|
|
349 |
professionals = gr.DataFrame(
|
350 |
+
label="Nearby Health Professionals",
|
351 |
headers=["Name", "Address"],
|
352 |
+
value=[]
|
353 |
)
|
354 |
|
355 |
suggestions_markdown = gr.Markdown(label="Suggestions")
|
356 |
map_html = gr.HTML(label="Interactive Map")
|
357 |
|
|
|
358 |
def clear_input():
|
359 |
+
return "", []
|
360 |
|
361 |
submit_chatbot.click(
|
362 |
app_function_chatbot,
|
|
|
367 |
clear_chatbot.click(
|
368 |
clear_input,
|
369 |
inputs=None,
|
370 |
+
outputs=[user_input, chatbot]
|
371 |
)
|
372 |
|
373 |
with gr.Tab("Disease Prediction"):
|
|
|
396 |
disease_prediction_result = gr.Markdown(label="Predicted Diseases")
|
397 |
|
398 |
submit_disease.click(
|
399 |
+
lambda sym1, sym2, sym3, sym4, sym5: predict_disease([sym1, sym2, sym3, sym4, sym5]),
|
|
|
400 |
inputs=[symptom1, symptom2, symptom3, symptom4, symptom5],
|
401 |
outputs=disease_prediction_result
|
402 |
)
|
403 |
|
404 |
# Launch the Gradio application
|
405 |
+
if __name__ == "__main__":
|
406 |
+
app.launch(share=True)
|