Spaces:
Sleeping
Sleeping
File size: 3,646 Bytes
ee89cd3 c03c66e cd072df c03c66e cd072df c03c66e cd072df c03c66e ee89cd3 cd072df c03c66e cd072df ee89cd3 cd072df c03c66e cd072df c03c66e ee89cd3 cd072df c03c66e cd072df c03c66e cd072df ee89cd3 cd072df c03c66e cd072df c03c66e cd072df c03c66e cd072df c03c66e cd072df c03c66e cd072df ee89cd3 cd072df ee89cd3 cd072df c03c66e cd072df ee89cd3 cd072df 7b8e63b ee89cd3 c03c66e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import asyncio
import io
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend — must be set before importing pyplot
import matplotlib.pyplot as plt
from huggingface_hub import InferenceClient
from joblib import Parallel, delayed
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans
from sklearn.manifold import TSNE
from umap import UMAP
# ---- Precomputed Elements ----
file_path = 'symbipredict_2022_filtered.csv'
df = pd.read_csv(file_path)

# Sentence embeddings for every symptom description (float32 to halve memory).
model = SentenceTransformer("all-MiniLM-L6-v2")
embedding_arr = model.encode(df['symptoms']).astype(np.float32)

# Clustering
optimal_n_clusters = 41
kmeans = MiniBatchKMeans(n_clusters=optimal_n_clusters,
                         batch_size=1024,
                         random_state=42)
kmeans_labels = kmeans.fit_predict(embedding_arr)
# BUG FIX: the groupby below reads df['cluster'], but that column was never
# created — attach the fitted k-means labels so the mapping can be built.
df['cluster'] = kmeans_labels

# Dimensionality Reduction (fit once; create_plot reuses the fitted reducer
# to project new queries into the same 2-D space)
umap = UMAP(n_components=2, random_state=42)
embedding_umap = umap.fit_transform(embedding_arr)

# Precomputed Mappings: cluster id -> unique prognoses seen in that cluster
cluster_prognosis_map = (df.groupby('cluster')['prognosis']
                         .unique().to_dict())
# ---- Cached Functions ----
@lru_cache(maxsize=100)
def cached_encode(text):
    """Embed *text* with the shared model, memoizing repeated queries.

    Returns the 1-D numpy embedding vector for the single input string.
    """
    vectors = model.encode([text], convert_to_numpy=True)
    return vectors[0]
# ---- Optimized Plotting ----
# One figure/axes pair is created up front and reused by create_plot
# (ax.clear() wipes it between requests) to avoid per-request allocation.
fig, ax = plt.subplots(figsize=(14, 10))
cmap = plt.get_cmap('tab20', optimal_n_clusters)
def create_plot(message, query_embedding):
    """Render the UMAP cluster scatter with the user's query highlighted.

    Parameters
    ----------
    message : str
        The user's text; used only as the legend label for the query marker.
    query_embedding : np.ndarray
        1-D embedding of the query (as returned by ``cached_encode``).

    Returns
    -------
    io.BytesIO
        PNG image buffer, rewound to offset 0.
    """
    ax.clear()
    # All corpus points, colored by their k-means cluster assignment
    ax.scatter(embedding_umap[:, 0], embedding_umap[:, 1],
               c=kmeans_labels, cmap=cmap,
               edgecolor='k', linewidth=0.5, alpha=0.7)
    # Project the query into the already-fitted 2-D UMAP space and mark it
    query_umap = umap.transform(query_embedding.reshape(1, -1))
    ax.scatter(query_umap[:, 0], query_umap[:, 1],
               c='red', marker='X', s=200,
               label=f'Query: {message}')
    # Finalize plot
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_title("Medical Condition Clustering")
    buf = io.BytesIO()
    # BUG FIX: plt.savefig writes pyplot's "current" figure, which is not
    # guaranteed to be `fig` — especially when this runs in a worker thread.
    # Save the module-level figure explicitly.
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return buf
# ---- Response Function ----
# Shared Hugging Face inference client for the zephyr-7b-beta chat model.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
async def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio chat handler: embed the query, find its cluster, plot it, and
    combine the LLM's answer with the prognoses known for that cluster.

    Returns ``(text, plot_buffer)`` on success or ``("Error: ...", None)``
    on any failure (surfaced to the UI instead of crashing the handler).
    """
    try:
        # Embed the user message (memoized across repeated queries)
        query_embedding = cached_encode(message)
        # Assign the query to its nearest k-means cluster
        query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
        # Matplotlib rendering is blocking — keep it off the event loop
        plot_buf = await asyncio.to_thread(
            create_plot, message, query_embedding
        )
        # BUG FIX: InferenceClient.chat_completion is synchronous, so it
        # cannot be awaited directly; run it in a worker thread. Also pass
        # the previously-ignored system_message as the system role.
        completion = await asyncio.to_thread(
            client.chat_completion,
            [{"role": "system", "content": system_message},
             {"role": "user", "content": message}],
            max_tokens=max_tokens,
            stream=False,
            temperature=temperature,
            top_p=top_p,
        )
        # BUG FIX: the completion is a structured object, not a string —
        # extract the assistant's message text before formatting.
        llm_text = completion.choices[0].message.content
        # .get guards against an unseen cluster id (shouldn't happen once
        # df['cluster'] is populated, but fail soft rather than KeyError)
        prognoses = cluster_prognosis_map.get(query_cluster, [])
        full_response = (
            f"{llm_text}\n\nCluster {query_cluster} contains: "
            f"{', '.join(prognoses)}"
        )
        return full_response, plot_buf
    except Exception as e:
        return f"Error: {str(e)}", None
# ---- Gradio Interface ----
# Extra controls surfaced below the chat box; their values are forwarded to
# respond() as (system_message, max_tokens, temperature, top_p).
extra_controls = [
    gr.Textbox(value="Medical diagnosis assistant", label="System Role"),
    gr.Slider(512, 2048, value=512, step=128, label="Max Tokens"),
    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p"),
]
demo = gr.ChatInterface(respond, additional_inputs=extra_controls)

if __name__ == "__main__":
    demo.launch()
|