Spaces:
Sleeping
Sleeping
import gradio as gr | |
import tensorflow as tf | |
from transformers import TFAutoModel, AutoTokenizer | |
import numpy as np | |
import shap | |
from scipy.special import softmax | |
# Model and Tokenizer Setup
# Tokenizer and encoder are loaded from the same checkpoint so the token ids
# produced by the tokenizer line up with the encoder's vocabulary.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
# NOTE(review): TFAutoModel yields the bare encoder (downstream code reads
# `.last_hidden_state` from its output). The checkpoint name indicates an
# SST-2 sentiment fine-tune; its classification head is not used here and
# nothing in this file is trained on medical conditions.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    TFAutoModel.from_pretrained(MODEL_NAME),
)
# Constants
# Every description is truncated / padded to exactly this many tokens.
SEQ_LEN = 128
# Condition labels, in the order predictions are paired with scores; also
# used as the choices of the advice dropdown.
CONDITIONS = [
    "Common Cold",
    "COVID-19",
    "Allergies",
    "Anxiety Disorder",
    "Skin Infection",
    "Heart Condition",
    "Digestive Issues",
    "Migraine",
    "Muscle Strain",
    "Arthritis",
]
# Dynamic Condition Predictions
def predict_condition(description: str):
    """Score each entry of CONDITIONS for a free-text symptom description.

    The description is tokenized to exactly SEQ_LEN tokens and passed through
    the encoder. The first ``len(CONDITIONS)`` components of the [CLS]
    embedding are softmax-normalised into one score per condition.

    Args:
        description: Symptom text entered by the user.

    Returns:
        dict mapping each condition name to a Python float in [0, 1]. The
        values sum to 1, so the result can be fed directly to ``gr.Label``.
    """
    tokens = tokenizer(
        description,
        max_length=SEQ_LEN,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    # [CLS] token output for the single input string: shape (1, hidden_size).
    cls_embedding = model(tokens).last_hidden_state[:, 0, :]
    # NOTE(review): `model` is a bare encoder with no head trained on
    # CONDITIONS, so these are placeholder scores, not a diagnosis. A
    # classifier fine-tuned on these labels is needed for meaningful output.
    logits = cls_embedding.numpy()[0, : len(CONDITIONS)]
    # Normalise over only the entries actually reported (not the whole hidden
    # vector), so the scores form a distribution summing to 1. Because softmax
    # is monotonic, the ranking of conditions is unaffected.
    scores = softmax(logits)
    return {condition: float(score) for condition, score in zip(CONDITIONS, scores)}
# Lifestyle Tips
# One tip per condition offered in the advice dropdown (every entry of
# CONDITIONS); anything else falls back to a generic recommendation.
LIFESTYLE_TIPS = {
    "Common Cold": "Rest, stay hydrated, and use saline nasal sprays.",
    "COVID-19": "Quarantine, stay hydrated, and seek medical attention if symptoms worsen.",
    "Allergies": "Avoid allergens, take antihistamines, and use air purifiers.",
    "Anxiety Disorder": "Practice mindfulness, exercise, and seek therapy if needed.",
    "Skin Infection": "Keep the area clean, use topical creams, and consult a dermatologist.",
    "Heart Condition": "Seek prompt medical evaluation, and call emergency services for chest pain or shortness of breath.",
    "Digestive Issues": "Eat smaller, balanced meals, stay hydrated, and see a doctor if symptoms persist.",
    "Migraine": "Rest in a dark, quiet room, stay hydrated, and keep a diary of possible triggers.",
    "Muscle Strain": "Rest the muscle, apply ice, and return to activity gradually.",
    "Arthritis": "Stay active with low-impact exercise and consult a rheumatologist.",
}


def get_lifestyle_advice(condition: str):
    """Return the lifestyle tip for *condition*.

    Args:
        condition: A condition name, normally one of CONDITIONS. The value may
            possibly be ``None`` if nothing is selected in the dropdown; that
            case takes the fallback.

    Returns:
        The matching tip from LIFESTYLE_TIPS, or a generic recommendation to
        consult a healthcare professional when no tip exists.
    """
    return LIFESTYLE_TIPS.get(condition, "Consult a healthcare professional for guidance.")
# Interactive Health Visualization (SHAP)
def explain_prediction(text: str):
    """Render a SHAP text plot showing which words drive each condition score.

    Args:
        text: Symptom description to explain.

    Returns:
        An HTML string from ``shap.plots.text`` (for the ``gr.HTML`` output),
        or a short HTML notice when *text* is blank.
    """
    if not text or not text.strip():
        return "<p>Please enter a symptom description to explain.</p>"

    def score_batch(texts):
        # SHAP calls this with a batch of (masked) strings and expects an array
        # of shape (batch, n_outputs). predict_condition takes one string and
        # returns a dict, so adapt it per text, in CONDITIONS order.
        # NOTE(review): one model call per text; batching would be faster.
        rows = []
        for item in texts:
            prediction = predict_condition(str(item))
            rows.append([prediction[condition] for condition in CONDITIONS])
        return np.array(rows)

    # A tokenizer passed as the masker is wrapped by SHAP in a text masker.
    explainer = shap.Explainer(score_batch, tokenizer, output_names=CONDITIONS)
    shap_values = explainer([text])
    return shap.plots.text(shap_values, display=False)
# Symptom Tracker (Simple Implementation)
# NOTE(review): this list lives at module level, so every request handled by
# this process shares one history (all users of the app see the same list),
# and it is never trimmed, so it grows for the life of the process.
# Per-session state (e.g. gr.State) would fix both but needs UI wiring changes.
symptom_history = []


def log_symptom(symptom: str):
    """Record *symptom* in the shared history and report the running total.

    Leading and trailing whitespace is stripped; blank input is not recorded.

    Args:
        symptom: Free-text symptom from the tracker textbox.

    Returns:
        A confirmation (or "nothing logged") message for the tracker output.
    """
    cleaned = (symptom or "").strip()
    if not cleaned:
        return f"Nothing logged: please enter a symptom. Total symptoms logged: {len(symptom_history)}"
    symptom_history.append(cleaned)
    return f"Logged: {cleaned}. Total symptoms logged: {len(symptom_history)}"


def display_symptom_trends():
    """Return the 10 most recently logged symptoms, oldest first, one per line."""
    return "\n".join(symptom_history[-10:])
# Gradio UI Design
# Custom CSS applied to every <textarea> in the app.
css = """
textarea { background-color: transparent; border: 1px solid #6366f1; }
"""

with gr.Blocks(title="MedAI Compass", css=css, theme=gr.themes.Soft()) as app:
    # Header
    gr.HTML("<h1>MedAI Compass: Comprehensive Symptom and Health Guide</h1>")
    # Diagnosis scores come from a general-purpose checkpoint (MODEL_NAME),
    # not a model trained on CONDITIONS, so state the tool's limits up front.
    gr.Markdown(
        "**Disclaimer:** MedAI Compass is for informational purposes only and "
        "is not a substitute for professional medical advice, diagnosis, or "
        "treatment. In an emergency, contact your local emergency services."
    )

    # Section: Symptom Diagnosis
    with gr.Row():
        gr.Markdown("## Symptom Diagnosis")
        input_description = gr.Textbox(label="Describe your symptom")
        diagnose_btn = gr.Button("Diagnose")
        diagnosis_output = gr.Label(label="Possible Conditions")
        diagnose_btn.click(predict_condition, inputs=input_description, outputs=diagnosis_output)

    # Section: SHAP Analysis
    with gr.Row():
        gr.Markdown("## Explain Predictions")
        shap_text_input = gr.Textbox(label="Enter Symptom Description for Analysis")
        shap_btn = gr.Button("Generate Explanation")
        shap_output = gr.HTML()
        shap_btn.click(explain_prediction, inputs=shap_text_input, outputs=shap_output)

    # Section: Personalized Advice
    with gr.Row():
        gr.Markdown("## Personalized Health Advice")
        condition_input = gr.Dropdown(choices=CONDITIONS, label="Select a Condition")
        advice_output = gr.Textbox(label="Advice")
        advice_btn = gr.Button("Get Advice")
        advice_btn.click(get_lifestyle_advice, inputs=condition_input, outputs=advice_output)

    # Section: Symptom Tracker
    with gr.Row():
        gr.Markdown("## Symptom Tracker")
        tracker_input = gr.Textbox(label="Log a Symptom")
        tracker_btn = gr.Button("Log Symptom")
        tracker_output = gr.Textbox(label="Logged Symptoms")
        tracker_btn.click(log_symptom, inputs=tracker_input, outputs=tracker_output)
        tracker_display_btn = gr.Button("Display Trends")
        tracker_trends_output = gr.Textbox(label="Symptom Trends")
        tracker_display_btn.click(display_symptom_trends, outputs=tracker_trends_output)

    # Footer
    gr.HTML("<p>© 2024 MedAI Compass. All Rights Reserved.</p>")

app.launch()