# SymptomsSearch — app.py (Hugging Face Space, revision cd072df)
import asyncio
import io
from functools import lru_cache

import gradio as gr
import matplotlib
matplotlib.use('Agg')  # headless backend; must be set before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import InferenceClient
from joblib import Parallel, delayed
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans
from sklearn.manifold import TSNE
from umap import UMAP
# ---- Precomputed Elements ----
# Load the symptom/prognosis dataset and embed every symptom string once at
# startup so per-request work is just a single query embedding.
file_path = 'symbipredict_2022_filtered.csv'
df = pd.read_csv(file_path)
model = SentenceTransformer("all-MiniLM-L6-v2")
# float32 halves memory vs. the default float64; plenty of precision for
# cosine/Euclidean similarity over sentence embeddings.
embedding_arr = model.encode(df['symptoms']).astype(np.float32)

# Clustering
optimal_n_clusters = 41  # presumably chosen offline (elbow/silhouette) — TODO confirm
kmeans = MiniBatchKMeans(n_clusters=optimal_n_clusters,
                         batch_size=1024,
                         random_state=42)
kmeans_labels = kmeans.fit_predict(embedding_arr)
# BUG FIX: the prognosis map must be keyed by these k-means labels, because
# `respond` indexes it with `kmeans.predict(...)`. Previously `kmeans_labels`
# was never written back to the DataFrame, so `groupby('cluster')` either
# raised KeyError (no such column) or grouped by stale labels from the CSV.
df['cluster'] = kmeans_labels

# Dimensionality Reduction (2-D projection used only for the scatter plot)
umap = UMAP(n_components=2, random_state=42)
embedding_umap = umap.fit_transform(embedding_arr)

# Precomputed Mappings: cluster id -> unique prognoses observed in that cluster
cluster_prognosis_map = (df.groupby('cluster')['prognosis']
                         .unique().to_dict())
# ---- Cached Functions ----
@lru_cache(maxsize=100)
def cached_encode(text):
    """Embed a single query string, memoizing the 100 most recent queries.

    Returns the 1-D embedding vector (numpy array) for `text`.
    """
    (vector,) = model.encode([text], convert_to_numpy=True)
    return vector
# ---- Optimized Plotting ----
# A single figure/axes pair is allocated once and reused for every request:
# clearing the axes is cheaper than rebuilding a figure and avoids leaking
# figure objects.
# NOTE(review): this shared state is NOT thread-safe, and `respond` runs
# create_plot via asyncio.to_thread — concurrent requests could interleave
# drawing. Confirm the Space runs requests serially, or move fig/ax inside
# create_plot.
fig = plt.figure(figsize=(14, 10))
ax = fig.add_subplot(111)
cmap = plt.get_cmap('tab20', optimal_n_clusters)

def create_plot(message, query_embedding):
    """Render the 2-D UMAP scatter of all symptoms plus the query point.

    Parameters
    ----------
    message : str
        The user's query text (shown in the legend label).
    query_embedding : np.ndarray
        1-D sentence embedding of the query.

    Returns
    -------
    io.BytesIO
        PNG image buffer, rewound to position 0.
    """
    ax.clear()
    # Plot all dataset points, colored by their k-means cluster
    ax.scatter(embedding_umap[:, 0], embedding_umap[:, 1],
               c=kmeans_labels, cmap=cmap,
               edgecolor='k', linewidth=0.5, alpha=0.7)
    # Project the query into the same UMAP space and highlight it
    query_umap = umap.transform(query_embedding.reshape(1, -1))
    ax.scatter(query_umap[:, 0], query_umap[:, 1],
               c='red', marker='X', s=200,
               label=f'Query: {message}')
    # Finalize plot
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_title("Medical Condition Clustering")
    buf = io.BytesIO()
    # BUG FIX: save *this* figure explicitly instead of plt.savefig, which
    # writes pyplot's "current" figure — wrong if any other code ever creates
    # a figure after module import.
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return buf
# ---- Response Function ----
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

async def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer a chat message: LLM reply plus the query's symptom-cluster info.

    Parameters mirror the Gradio ChatInterface contract: `message` is the user
    text, `history` the prior turns (unused), and the rest come from the
    additional-input widgets.

    Returns
    -------
    (str, io.BytesIO | None)
        The combined response text and the cluster plot, or an error string
        and None on failure.
    """
    try:
        # Encoding (memoized per unique query string)
        query_embedding = cached_encode(message)
        # Cluster prediction
        query_cluster = int(kmeans.predict(query_embedding.reshape(1, -1))[0])
        # Plot generation off the event loop (matplotlib is blocking)
        plot_buf = await asyncio.to_thread(
            create_plot, message, query_embedding
        )
        # BUG FIX: InferenceClient.chat_completion is synchronous — awaiting
        # it directly raises TypeError. Run it in a worker thread instead.
        # Also pass the user-configurable system role, which was previously
        # accepted but silently ignored.
        completion = await asyncio.to_thread(
            client.chat_completion,
            [{"role": "system", "content": system_message},
             {"role": "user", "content": message}],
            max_tokens=max_tokens,
            stream=False,
            temperature=temperature,
            top_p=top_p,
        )
        # BUG FIX: extract the assistant text; the original interpolated the
        # whole ChatCompletionOutput object into the reply.
        llm_text = completion.choices[0].message.content
        # .get guards against a cluster id with no precomputed prognoses.
        prognoses = cluster_prognosis_map.get(query_cluster, [])
        full_response = (
            f"{llm_text}\n\nCluster {query_cluster} contains: "
            f"{', '.join(prognoses)}"
        )
        return full_response, plot_buf
    except Exception as e:
        # Top-level UI boundary: surface the error to the chat instead of
        # crashing the request.
        return f"Error: {str(e)}", None
# ---- Gradio Interface ----
# NOTE(review): `respond` returns a (text, plot_buffer) tuple, but a plain
# ChatInterface renders only the chat text — confirm the plot actually
# appears in the UI (an additional output component may be required).
demo = gr.ChatInterface(
    respond,
    # These widgets map positionally onto respond(message, history,
    # system_message, max_tokens, temperature, top_p).
    additional_inputs=[
        gr.Textbox(value="Medical diagnosis assistant", label="System Role"),
        gr.Slider(512, 2048, value=512, step=128, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
    ]
)

if __name__ == "__main__":
    demo.launch()