# SymptomsSearch / app.py
# (Hugging Face Space — commit 2657bef, "Update app.py" by AAA1988)
import gradio as gr
import numpy as np
import pandas as pd
from sklearn.cluster import MiniBatchKMeans
from sentence_transformers import SentenceTransformer
from functools import lru_cache
from huggingface_hub import AsyncInferenceClient
from sklearn.decomposition import PCA
# ------ Data Loading ------
# The CSV is expected to provide at least a 'symptoms' text column and a
# 'prognosis' label column (both are read below).
df = pd.read_csv("symbipredict_2022_filtered.csv")
# ------ Model Initialization ------
model = SentenceTransformer("all-MiniLM-L6-v2")
# Embed every symptom description; float32 halves memory vs. float64.
embedding_arr = model.encode(df['symptoms']).astype(np.float32)
# ------ Clustering Setup ------
# Fixed random_state keeps cluster assignments reproducible across restarts.
kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
cluster_labels = kmeans.fit_predict(embedding_arr)
# Per cluster id: the top-3 prognoses as {label: relative frequency in [0, 1]}.
cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(
    lambda x: x.value_counts(normalize=True).head(3).to_dict()
)
# ------ PCA Initialization ------
# NOTE(review): this fitted PCA is not referenced anywhere else in this file —
# confirm it is still needed (e.g. for a removed visualization) or drop it.
pca = PCA(n_components=2).fit(embedding_arr)
# ------ Cached Functions ------
@lru_cache(maxsize=100)
def cached_encode(text):
    """Embed *text* with the sentence-transformer, memoizing recent queries.

    The bounded LRU cache (100 entries) avoids re-encoding repeated user
    messages; the returned value is a NumPy embedding vector.
    """
    embedding = model.encode(text, convert_to_numpy=True)
    return embedding
# ------ Async Inference Client ------
# Remote chat model used by respond(); inference happens over the network,
# so this requires Hub connectivity at runtime.
client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
# ------ Streaming Response Function ------
# ------ Streaming Response Function ------
async def respond(message, history, system_message):
    """Stream a chat reply, annotated with cluster-based prognosis likelihoods.

    Parameters
    ----------
    message : str
        The user's symptom description.
    history : list
        Prior chat turns (supplied by gr.ChatInterface; unused here).
    system_message : str
        System-role prompt from the extra Textbox input.

    Yields
    ------
    str
        Progressively longer partial responses, each suffixed with the
        matched cluster's top prognoses, so Gradio renders a live stream.
    """
    try:
        # Embed the query and locate its symptom cluster.
        query_embedding = cached_encode(message)
        query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]

        # The prognosis summary is loop-invariant: compute it once up front
        # instead of after the stream has been consumed.
        prognosis_info = cluster_prognosis_map[query_cluster]
        formatted_prognoses = ", ".join(
            f"{k}: {v * 100:.1f}%" for k, v in prognosis_info.items()
        )

        # Open a streaming chat completion against the remote model.
        stream = await client.chat_completion(
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": message},
            ],
            max_tokens=512,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        )

        full_response = ""
        async for chunk in stream:
            content = chunk.choices[0].delta.content
            if content:
                full_response += content
                # Yield per chunk so the UI streams; buffering the whole
                # completion before the first yield defeats streaming.
                yield f"{full_response}\n\nCommon prognoses: {formatted_prognoses}"

        # Guarantee at least one yield even if the stream produced no content
        # (Gradio replaces the message on each yield, so this never duplicates).
        yield f"{full_response}\n\nCommon prognoses: {formatted_prognoses}"
    except Exception as e:
        # Surface any failure (network, inference, clustering) in the chat UI
        # rather than crashing the event handler.
        yield f"Error: {str(e)}"
# ------ Gradio Interface ------
with gr.Blocks() as demo:
    gr.Markdown("# Medical Diagnosis Assistant")
    # Chat UI wired to the streaming respond() generator; the Textbox value
    # is passed through as the system_message argument.
    chatbot = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value="Medical diagnosis assistant", label="System Role")
        ],
        # NOTE(review): with additional_inputs set, some Gradio versions expect
        # each example to also supply a value for the extra input — confirm
        # against the installed Gradio version.
        examples=[["I have a headache and fever"]]  # Example query for guidance
    )
# Launch only when executed directly (Spaces runs this as the main module).
if __name__ == "__main__":
    demo.launch(max_threads=10)