import asyncio
import io
import threading
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend for server-side rendering
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import MiniBatchKMeans
from sentence_transformers import SentenceTransformer
from umap import UMAP
from huggingface_hub import AsyncInferenceClient
# ---- Precomputed Elements ----
file_path = 'symbipredict_2022_filtered.csv'
df = pd.read_csv(file_path)
model = SentenceTransformer("all-MiniLM-L6-v2")
# encode() expects a list of strings, so convert the Series explicitly
embedding_arr = model.encode(df['symptoms'].tolist()).astype(np.float32)
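# Hedged sketch (not called by the app): cosine nearest neighbours over the
# precomputed embeddings, e.g. to surface the most similar recorded symptom
# rows for a query. The function name and k default are illustrative
# assumptions, not part of the original pipeline.
def nearest_symptoms(query_vec, k=5):
    emb = embedding_arr / np.linalg.norm(embedding_arr, axis=1, keepdims=True)
    q = query_vec / np.linalg.norm(query_vec)
    idx = np.argsort(emb @ q)[::-1][:k]
    return df['symptoms'].iloc[idx].tolist()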
# Clustering
optimal_n_clusters = 41
kmeans = MiniBatchKMeans(n_clusters=optimal_n_clusters,
                         batch_size=1024,
                         random_state=42)
kmeans_labels = kmeans.fit_predict(embedding_arr)
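# Hedged sketch (offline, not run at startup): one way a cluster count such
# as 41 could be chosen, by scanning k over a range and keeping the best
# silhouette score. The candidate range here is an illustrative assumption.
def estimate_n_clusters(embeddings, k_candidates=range(10, 61, 5)):
    from sklearn.metrics import silhouette_score
    best_k, best_score = None, -1.0
    for k in k_candidates:
        labels = MiniBatchKMeans(n_clusters=k, batch_size=1024,
                                 random_state=42).fit_predict(embeddings)
        score = silhouette_score(embeddings, labels)
        if score > best_score:
            best_k, best_score = k, score
    return best_k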
# Dimensionality Reduction
# renamed from `umap` so the instance doesn't shadow the package name
umap_model = UMAP(n_components=2, random_state=42)
embedding_umap = umap_model.fit_transform(embedding_arr)
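# Hedged sketch (optional, not wired in): refitting k-means and UMAP on every
# cold start is the slow path on a sleeping Space; persisting the fitted
# artifacts with joblib and reloading them would skip it. Filenames are
# illustrative assumptions.
def save_artifacts(path_prefix="artifacts"):
    from joblib import dump
    dump(kmeans, f"{path_prefix}_kmeans.joblib")
    dump(umap_model, f"{path_prefix}_umap.joblib")
    np.save(f"{path_prefix}_embeddings.npy", embedding_arr)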
# Precomputed Mappings
df['cluster'] = kmeans_labels  # the groupby below needs this column
cluster_prognosis_map = (df.groupby('cluster')['prognosis']
                         .unique().to_dict())
# ---- Cached Functions ----
@lru_cache(maxsize=512)
def cached_encode(text):
    """Encode a single query string, memoised so repeat queries skip the model."""
    return model.encode([text], convert_to_numpy=True)[0]
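# functools.lru_cache exposes cached_encode.cache_info() for inspecting
# hit/miss counts at runtime (standard library behaviour), e.g.:
#   cached_encode.cache_info()  # CacheInfo(hits=..., misses=..., ...)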
# ---- Optimized Plotting ----
fig = plt.figure(figsize=(14, 10))
ax = fig.add_subplot(111)
cmap = plt.get_cmap('tab20', optimal_n_clusters)
_plot_lock = threading.Lock()  # fig/ax are shared across requests; serialize access
def create_plot(message, query_embedding):
    with _plot_lock:  # create_plot runs in a worker thread; guard the shared figure
        ax.clear()
        # Plot all points, coloured by cluster
        ax.scatter(embedding_umap[:, 0], embedding_umap[:, 1],
                   c=kmeans_labels, cmap=cmap,
                   edgecolor='k', linewidth=0.5, alpha=0.7)
        # Project and plot the query
        query_umap = umap_model.transform(query_embedding.reshape(1, -1))
        ax.scatter(query_umap[:, 0], query_umap[:, 1],
                   c='red', marker='X', s=200,
                   label=f'Query: {message}')
        # Finalize plot
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_title("Medical Condition Clustering")
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)  # return a PIL image so Gradio can display it
        img.load()
        return img
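# Quick manual check for the plotting path (assumption: run from a REPL, not
# at import time):
#   create_plot("fever and cough", cached_encode("fever and cough")).save("query.png")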
# ---- Response Function ----
# async client so chat_completion can be awaited inside respond()
client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
async def respond(message, history, system_message, max_tokens, temperature, top_p):
    try:
        # Encoding (memoised)
        query_embedding = cached_encode(message)
        # Cluster prediction
        query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
        # Render the plot off the event loop
        plot_image = await asyncio.to_thread(
            create_plot, message, query_embedding
        )
        # Async LLM response; system_message is passed through from the UI
        llm_response = await client.chat_completion(
            [{"role": "system", "content": system_message},
             {"role": "user", "content": message}],
            max_tokens=max_tokens,
            stream=False,
            temperature=temperature,
            top_p=top_p
        )
        answer = llm_response.choices[0].message.content
        # Combine responses
        full_response = (
            f"{answer}\n\nCluster {query_cluster} contains: "
            f"{', '.join(cluster_prognosis_map[query_cluster])}"
        )
        # Gradio's "messages" chat format accepts a list of assistant messages;
        # a component as content displays the plot (requires a recent Gradio)
        return [
            {"role": "assistant", "content": full_response},
            {"role": "assistant", "content": gr.Image(value=plot_image)},
        ]
    except Exception as e:
        return f"Error: {e}"
# ---- Gradio Interface ----
demo = gr.ChatInterface(
    respond,
    type="messages",  # needed for the list-of-messages return above
    additional_inputs=[
        gr.Textbox(value="Medical diagnosis assistant", label="System Role"),
        gr.Slider(512, 2048, value=512, step=128, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
    ]
)
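# Hedged note: with an async handler serving multiple users, enabling Gradio's
# request queue (demo.queue() before launch()) is a common pattern on Spaces;
# left out here to keep the original launch behaviour unchanged.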
if __name__ == "__main__":
    demo.launch()