import gradio as gr
import numpy as np
import pandas as pd
from functools import lru_cache
from huggingface_hub import AsyncInferenceClient
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans
from sklearn.decomposition import PCA

# ------ Data Loading ------
# Expects 'symptoms' and 'prognosis' columns.
df = pd.read_csv("symbipredict_2022_filtered.csv")

# ------ Model Initialization ------
model = SentenceTransformer("all-MiniLM-L6-v2")
embedding_arr = model.encode(df["symptoms"].tolist()).astype(np.float32)

# ------ Clustering Setup ------
# Group symptom embeddings into 10 clusters, then map each cluster to its
# three most frequent prognoses (as relative frequencies).
kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
cluster_labels = kmeans.fit_predict(embedding_arr)
cluster_prognosis_map = df.groupby(cluster_labels)["prognosis"].agg(
    lambda x: x.value_counts(normalize=True).head(3).to_dict()
)

# ------ PCA Initialization ------
# Fitted for an optional 2-D projection of the embedding space;
# not used by the UI below.
pca = PCA(n_components=2).fit(embedding_arr)

# ------ Cached Functions ------
@lru_cache(maxsize=100)
def cached_encode(text):
    """Encode a query string, memoizing repeated queries."""
    return model.encode(text, convert_to_numpy=True)

# ------ Async Inference Client ------
client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")

# ------ Streaming Response Function ------
async def respond(message, history, system_message):
    try:
        # Assign the query to its nearest symptom cluster
        query_embedding = cached_encode(message)
        query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]

        # Format the cluster's prognosis likelihoods once, up front,
        # rather than recomputing them for every streamed chunk
        prognosis_info = cluster_prognosis_map.loc[query_cluster]
        formatted_prognoses = ", ".join(
            f"{k}: {v * 100:.1f}%" for k, v in prognosis_info.items()
        )

        # Generate a streaming LLM response
        stream = await client.chat_completion(
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": message},
            ],
            max_tokens=512,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        )

        # Yield the growing response with the prognosis summary appended
        full_response = ""
        async for chunk in stream:
            content = chunk.choices[0].delta.content
            if content:
                full_response += content
                yield f"{full_response}\n\nCommon prognoses: {formatted_prognoses}"
    except Exception as e:
        yield f"Error: {e}"

# ------ Gradio Interface ------
with gr.Blocks() as demo:
    gr.Markdown("# Medical Diagnosis Assistant")
    chatbot = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value="Medical diagnosis assistant", label="System Role")
        ],
        # Example query paired with the default system role, so the example
        # row matches the additional input
        examples=[["I have a headache and fever", "Medical diagnosis assistant"]],
    )

if __name__ == "__main__":
    demo.launch(max_threads=10)
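
# ---- Optional offline sanity check (a minimal sketch, not part of the app) ----
# Assumes the CSV and models above loaded successfully; run these lines in a
# REPL to confirm a sample query lands in a cluster with a non-empty
# prognosis distribution before launching the UI.
#
#   emb = cached_encode("I have a headache and fever")
#   cluster = int(kmeans.predict(emb.reshape(1, -1))[0])
#   print(cluster, cluster_prognosis_map.loc[cluster])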