# Spaces: Sleeping
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from sklearn.cluster import MiniBatchKMeans | |
from sentence_transformers import SentenceTransformer | |
from functools import lru_cache | |
from huggingface_hub import AsyncInferenceClient | |
from sklearn.decomposition import PCA | |
# ------ Data Loading ------
df = pd.read_csv("symbipredict_2022_filtered.csv")
# ------ Model Initialization ------
model = SentenceTransformer("all-MiniLM-L6-v2")
# SentenceTransformer.encode expects a str or a list of str. Hand it a plain
# list (so it never indexes the Series by label) and turn missing symptom
# cells (NaN floats, which the tokenizer cannot handle) into empty strings.
# Keeping one entry per row preserves alignment with `df` for the groupby below.
embedding_arr = model.encode(
    df["symptoms"].fillna("").astype(str).tolist()
).astype(np.float32)
# ------ Clustering Setup ------
# Group the symptom embeddings into 10 clusters; the fixed seed keeps the
# assignment reproducible between restarts.
kmeans = MiniBatchKMeans(random_state=42, n_clusters=10)
cluster_labels = kmeans.fit_predict(embedding_arr)


def _top_three_shares(prognoses):
    # Relative frequency of the three most common prognoses in one cluster,
    # as {prognosis: fraction_of_cluster_rows}.
    return prognoses.value_counts(normalize=True).head(3).to_dict()


# Series indexed by cluster label -> dict of top prognosis shares.
cluster_prognosis_map = df.groupby(cluster_labels)["prognosis"].agg(_top_three_shares)
# ------ PCA Initialization ------
# 2-component projection fitted on the same embeddings
# (NOTE(review): not referenced anywhere in the visible code).
pca = PCA(n_components=2)
pca.fit(embedding_arr)
# ------ Cached Functions ------
@lru_cache(maxsize=1024)
def cached_encode(text):
    """Return the sentence embedding for *text*, memoised per distinct string.

    ``lru_cache`` was imported for this purpose but never applied, so every
    call re-ran the transformer. The cache is bounded so arbitrary user input
    cannot grow it without limit.

    Args:
        text: The string to embed; it must be hashable to be cached.

    Returns:
        The NumPy embedding produced by ``model.encode(..., convert_to_numpy=True)``.
        Repeated inputs return the *same* array object, so callers must not
        modify it in place.
    """
    return model.encode(text, convert_to_numpy=True)
# ------ Async Inference Client ------
# Remote chat model that respond() streams completions from.
client = AsyncInferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
# ------ Streaming Response Function ------
def _history_to_messages(history):
    """Convert Gradio chat history into chat-completion message dicts.

    Accepts either Gradio history shape: a list of ``[user, bot]`` pairs, or a
    list of ``{"role": ..., "content": ...}`` dicts. Only plain-text user and
    assistant turns are kept; anything else (e.g. file attachments, pending
    ``None`` replies) is skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, bot_msg = turn
            if isinstance(user_msg, str):
                messages.append({"role": "user", "content": user_msg})
            if isinstance(bot_msg, str):
                messages.append({"role": "assistant", "content": bot_msg})
    return messages


def _format_prognoses(cluster):
    """Render the top prognosis shares of *cluster* as 'name: pct%' pairs."""
    # .get: a cluster that received no training rows has no entry in the map.
    prognosis_info = cluster_prognosis_map.get(cluster, {})
    formatted = ", ".join(
        f"{name}: {share * 100:.1f}%" for name, share in prognosis_info.items()
    )
    return formatted or "unavailable"


async def respond(message, history, system_message):
    """Stream an LLM reply to *message*, then append cluster-based prognoses.

    Args:
        message: The user's latest chat message.
        history: Previous turns as supplied by ``gr.ChatInterface``; these are
            forwarded to the model so it sees the conversation.
        system_message: Content of the system prompt.

    Yields:
        The reply accumulated so far after each streamed chunk, then the full
        reply followed by a "Common prognoses" line for the nearest symptom
        cluster. On any failure, a single "Error: ..." string is yielded
        instead of raising, so the message is shown in the chat UI.
    """
    try:
        # Assign the query to the nearest symptom cluster fitted at start-up.
        query_embedding = cached_encode(message)
        query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
        prognosis_text = _format_prognoses(query_cluster)

        messages = [{"role": "system", "content": system_message}]
        messages.extend(_history_to_messages(history))
        messages.append({"role": "user", "content": message})

        stream = await client.chat_completion(
            messages=messages,
            max_tokens=512,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        )
        full_response = ""
        async for chunk in stream:
            # Some stream events carry no choices (e.g. trailing metadata).
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                full_response += content
                # Yield progressively so the UI actually streams the reply.
                yield full_response
        yield f"{full_response}\n\nCommon prognoses: {prognosis_text}"
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        yield f"Error: {str(e)}"
# ------ Gradio Interface ------
SYSTEM_ROLE_DEFAULT = "Medical diagnosis assistant"

with gr.Blocks() as demo:
    gr.Markdown("# Medical Diagnosis Assistant")
    chatbot = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value=SYSTEM_ROLE_DEFAULT, label="System Role")
        ],
        # With additional_inputs, each example must be
        # [message, *values_for_additional_inputs]; supplying the System Role
        # keeps that textbox populated when the example is clicked.
        examples=[["I have a headache and fever", SYSTEM_ROLE_DEFAULT]],
    )

if __name__ == "__main__":
    demo.launch(max_threads=10)