# FlinShaHealth / app.py
# Uploaded by mgbam in verified commit dfff6b7 ("Upload app.py"); original size 4.2 kB.
import gradio as gr
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
import numpy as np
import shap
from scipy.special import softmax
# Model and Tokenizer Setup
# Hugging Face Hub checkpoint that provides both the tokenizer and the encoder
# weights used by the prediction code in this app.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# Both objects are created once, at import time, and reused by every request.
# NOTE(review): judging by its name this is an SST-2 sentiment checkpoint, and
# TFAutoModel loads only the base encoder (no classification head) — nothing
# here is trained to recognise medical conditions.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = TFAutoModel.from_pretrained(MODEL_NAME)
# Constants
# Token length that every description is truncated or padded to.
SEQ_LEN = 128

# Candidate conditions in display order. Model scores are paired with these
# names by position, and the advice dropdown offers exactly this list.
CONDITIONS = [
    "Common Cold",
    "COVID-19",
    "Allergies",
    "Anxiety Disorder",
    "Skin Infection",
    "Heart Condition",
    "Digestive Issues",
    "Migraine",
    "Muscle Strain",
    "Arthritis",
]
# Dynamic Condition Predictions
def predict_condition(description: str):
    """Score every entry of CONDITIONS for one free-text symptom description.

    The description is tokenized (truncated/padded to SEQ_LEN tokens) and run
    through ``model``. The first ``len(CONDITIONS)`` components of the [CLS]
    hidden state are then softmax-normalised into one score per condition.

    NOTE(review): ``model`` is a bare DistilBERT encoder, and nothing maps its
    hidden units to medical conditions. These scores are placeholders, not
    diagnostic probabilities. A classifier trained on the CONDITIONS labels is
    required for meaningful output.

    Args:
        description: Free-text description of the user's symptoms.

    Returns:
        Dict mapping each condition name to a Python float. The values sum to 1.
    """
    tokens = tokenizer(
        description,
        max_length=SEQ_LEN,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    # Hidden state at position 0 ([CLS]) of the single input row.
    cls_vector = model(tokens).last_hidden_state[0, 0, :].numpy()
    # Softmax over only the components that are paired with a label, so the
    # returned scores form a proper distribution over CONDITIONS.
    logits = cls_vector[: len(CONDITIONS)]
    scores = softmax(logits)
    # Plain floats keep the result JSON-friendly for the Gradio Label output.
    return {condition: float(score) for condition, score in zip(CONDITIONS, scores)}
# Lifestyle Tips
# General self-care guidance keyed by condition name. Keys mirror the condition
# labels offered in the UI dropdown, so every selectable condition has a tip.
LIFESTYLE_TIPS = {
    "Common Cold": "Rest, stay hydrated, and use saline nasal sprays.",
    "COVID-19": "Quarantine, stay hydrated, and seek medical attention if symptoms worsen.",
    "Allergies": "Avoid allergens, take antihistamines, and use air purifiers.",
    "Anxiety Disorder": "Practice mindfulness, exercise, and seek therapy if needed.",
    "Skin Infection": "Keep the area clean, use topical creams, and consult a dermatologist.",
    "Heart Condition": "Seek emergency care for chest pain, fainting, or severe shortness of breath; otherwise follow your cardiologist's plan, limit salt, and avoid smoking.",
    "Digestive Issues": "Eat smaller meals, stay hydrated, limit fatty or spicy foods, and see a doctor if symptoms persist or you notice blood.",
    "Migraine": "Rest in a dark, quiet room, stay hydrated, track your triggers, and ask a doctor about treatment options.",
    "Muscle Strain": "Rest the muscle, apply ice during the first 48 hours, use compression, and return to activity gradually.",
    "Arthritis": "Stay active with low-impact exercise, maintain a healthy weight, and consult a rheumatologist about treatment.",
}


def get_lifestyle_advice(condition: str):
    """Return the self-care tip for ``condition``.

    Unknown conditions, and ``None`` (for example, an empty dropdown), fall back
    to a generic recommendation to consult a healthcare professional.
    """
    return LIFESTYLE_TIPS.get(condition, "Consult a healthcare professional for guidance.")
# Interactive Health Visualization (SHAP)
def explain_prediction(text: str):
    """Return an HTML SHAP text plot showing how each word affects the condition scores.

    SHAP calls the model with a batch (an array) of partially masked strings
    and expects a numeric array of shape ``(n_texts, n_outputs)`` in return.
    ``predict_condition`` scores one string and returns a dict, so
    ``_batch_scores`` adapts between the two. Its columns follow the order of
    CONDITIONS, which also labels the outputs in the plot.

    Args:
        text: Symptom description to explain.

    Returns:
        HTML string suitable for a Gradio ``HTML`` component.
    """
    if not text or not text.strip():
        return "<p>Please enter a symptom description to explain.</p>"

    def _batch_scores(texts):
        # Score each (masked) variant separately. str() handles numpy string
        # elements, which the tokenizer would reject.
        rows = []
        for variant in texts:
            prediction = predict_condition(str(variant))
            rows.append([float(prediction.get(name, 0.0)) for name in CONDITIONS])
        return np.array(rows, dtype=float).reshape(len(rows), len(CONDITIONS))

    # Passing the tokenizer as the masker makes SHAP mask the text token by token.
    explainer = shap.Explainer(_batch_scores, tokenizer, output_names=CONDITIONS)
    shap_values = explainer([text])
    # display=False returns the HTML instead of rendering it in a notebook.
    return shap.plots.text(shap_values, display=False)
# Symptom Tracker (Simple Implementation)
# NOTE(review): this log is a module-level list, so it is shared by every
# session served by this process: one user can see another user's symptoms.
# Per-user history would need gr.State wired through the UI.
symptom_history = []


def log_symptom(symptom: str):
    """Append a symptom to the shared log and return a confirmation message.

    Surrounding whitespace is stripped before storing. Blank input (including
    ``None``) is not logged, so empty submissions don't inflate the count.
    """
    entry = (symptom or "").strip()
    if not entry:
        return "Please enter a symptom to log."
    symptom_history.append(entry)
    return f"Logged: {entry}. Total symptoms logged: {len(symptom_history)}"


def display_symptom_trends():
    """Return the 10 most recently logged symptoms, oldest first, one per line."""
    return "\n".join(symptom_history[-10:])
# Gradio UI Design
# Custom stylesheet handed to gr.Blocks: text areas get a transparent
# background and a thin indigo border.
css = (
    "\n"
    "textarea { background-color: transparent; border: 1px solid #6366f1; }\n"
)
with gr.Blocks(title="MedAI Compass", css=css, theme=gr.themes.Soft()) as app:
    # Header
    gr.HTML("<h1>MedAI Compass: Comprehensive Symptom and Health Guide</h1>")
    # The "Possible Conditions" scores come from a general-purpose language
    # model and the advice is generic, so state the limits up front.
    gr.Markdown(
        "**Disclaimer:** MedAI Compass is an informational demo, not a medical "
        "device, and its results are not a diagnosis. Consult a qualified "
        "healthcare professional about your symptoms, and seek emergency care "
        "for severe or worsening symptoms."
    )

    # Section: Symptom Diagnosis
    with gr.Row():
        gr.Markdown("## Symptom Diagnosis")
        input_description = gr.Textbox(label="Describe your symptom")
        diagnose_btn = gr.Button("Diagnose")
        diagnosis_output = gr.Label(label="Possible Conditions")
        diagnose_btn.click(predict_condition, inputs=input_description, outputs=diagnosis_output)

    # Section: SHAP Analysis
    with gr.Row():
        gr.Markdown("## Explain Predictions")
        shap_text_input = gr.Textbox(label="Enter Symptom Description for Analysis")
        shap_btn = gr.Button("Generate Explanation")
        shap_output = gr.HTML()
        shap_btn.click(explain_prediction, inputs=shap_text_input, outputs=shap_output)

    # Section: Personalized Advice
    with gr.Row():
        gr.Markdown("## Personalized Health Advice")
        condition_input = gr.Dropdown(choices=CONDITIONS, label="Select a Condition")
        advice_output = gr.Textbox(label="Advice")
        advice_btn = gr.Button("Get Advice")
        advice_btn.click(get_lifestyle_advice, inputs=condition_input, outputs=advice_output)

    # Section: Symptom Tracker
    with gr.Row():
        gr.Markdown("## Symptom Tracker")
        tracker_input = gr.Textbox(label="Log a Symptom")
        tracker_btn = gr.Button("Log Symptom")
        tracker_output = gr.Textbox(label="Logged Symptoms")
        tracker_btn.click(log_symptom, inputs=tracker_input, outputs=tracker_output)
        tracker_display_btn = gr.Button("Display Trends")
        tracker_trends_output = gr.Textbox(label="Symptom Trends")
        tracker_display_btn.click(display_symptom_trends, outputs=tracker_trends_output)

    # Footer
    gr.HTML("<p>© 2024 MedAI Compass. All Rights Reserved.</p>")


# Launch only when run as a script (e.g. `python app.py`), so importing this
# module for tests or tooling does not start a web server.
if __name__ == "__main__":
    app.launch()