# SymptomsSearch / app.py
# Last change by AAA1988: "Update app.py" (commit c03c66e, verified), 4.95 kB.
import inspect
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
# ---------------------------------------------------------------------------
# One-time setup, executed when the module is imported.
# ---------------------------------------------------------------------------

# Source data: the CSV must be uploaded to the Space next to this script.
# Later code reads its 'symptoms' and 'prognosis' columns.
file_path = 'symbipredict_2022_filtered.csv'
df = pd.read_csv(file_path)

# Sentence-embedding model loaded from a local directory shipped with the Space.
model_path = "all-MiniLM-L6-v2"
model = SentenceTransformer(model_path)

# Embed every symptom description; later code indexes these rows by df position.
embedding_arr = model.encode(df['symptoms'])

# Group the embeddings into a fixed number of clusters (41, the count the
# original author identified as optimal), with a fixed seed for repeatability.
optimal_n_clusters = 41
kmeans = KMeans(n_clusters=optimal_n_clusters, random_state=42)
kmeans_labels = kmeans.fit(embedding_arr).labels_

# Pair each record's prognosis with its cluster id, then keep one row per
# distinct (prognosis, cluster) pair, ordered by cluster id.
cluster_prognosis_mapping = df[['prognosis']].assign(cluster=kmeans_labels)
unique_clusters = cluster_prognosis_mapping.drop_duplicates().sort_values(by='cluster')

# Hosted chat model that generates the conversational part of each reply.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def _render_tsne_plot(message, query_embedding, query_cluster):
    """Render the dataset's t-SNE projection with the query marked; return a PNG buffer.

    t-SNE has no out-of-sample transform, so the query is projected together
    with every dataset embedding on each call (expensive for large datasets).
    Prognoses with at least one sample in ``query_cluster`` are outlined.
    """
    combined_embeddings = np.vstack([embedding_arr, query_embedding])
    # TSNE requires perplexity < n_samples; clamp so tiny datasets still work.
    perplexity = min(30, len(combined_embeddings) - 1)
    try:
        tsne = TSNE(n_components=2, perplexity=perplexity, max_iter=1000, random_state=42)
    except TypeError:
        # scikit-learn < 1.5 names the iteration cap `n_iter` (removed in 1.7).
        tsne = TSNE(n_components=2, perplexity=perplexity, n_iter=1000, random_state=42)
    projected = tsne.fit_transform(combined_embeddings)
    # The query was stacked last, so it is the final projected row.
    query_point, dataset_points = projected[-1], projected[:-1]

    prognoses = df['prognosis'].unique()
    cmap = plt.get_cmap('tab20', optimal_n_clusters)
    # rc_context scopes the font size to this figure instead of mutating global rcParams.
    with plt.rc_context({'font.size': 16}):
        fig, ax = plt.subplots(figsize=(14, 10))
        try:
            ax.grid()
            for i, prognosis in enumerate(prognoses):
                mask = (df['prognosis'] == prognosis).to_numpy()
                colors = [cmap(i / len(prognoses))] * int(mask.sum())
                # Use the fitted labels (same ones unique_clusters is built from) so
                # the highlight agrees with the cluster summary in the reply text.
                if np.any(kmeans_labels[mask] == query_cluster):
                    ax.scatter(dataset_points[mask, 0], dataset_points[mask, 1], c=colors,
                               edgecolor='black', linewidth=1,
                               label=f'{prognosis} (Cluster {query_cluster})')
                else:
                    ax.scatter(dataset_points[mask, 0], dataset_points[mask, 1], c=colors,
                               label=prognosis)
            ax.scatter(query_point[0], query_point[1], c='k', marker='D', s=200, label='query')
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            ax.tick_params(axis='x', labelrotation=45)
            ax.set_xlabel("t-SNE Component 1")
            ax.set_ylabel("t-SNE Component 2")
            ax.set_title(f'Query: "{message}" (Belongs to Cluster {query_cluster})')
            buf = io.BytesIO()
            # bbox_inches='tight' keeps the legend, which sits outside the axes, from being clipped.
            fig.savefig(buf, format='png', bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting fails.
            plt.close(fig)
    buf.seek(0)
    return buf


def _build_chat_messages(system_message, history, message):
    """Build the chat_completion message list: system prompt, prior turns, new user message.

    Accepts Gradio history either as (user, assistant) pairs or as
    {"role": ..., "content": ...} dicts; empty or non-text turns are skipped.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if isinstance(content, str) and content:
                messages.append({"role": turn.get("role", "user"), "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    return messages


def _stream_reply(messages, max_tokens, temperature, top_p):
    """Stream a completion from the module-level client and return the joined text."""
    pieces = []
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Some stream chunks carry no choices or a None delta (e.g. the final one).
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            pieces.append(token)
    return "".join(pieces)


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer a chat message and plot where it falls among the dataset's symptom clusters.

    Args:
        message: The user's new chat message (a symptom description).
        history: Prior chat turns, as (user, assistant) pairs or role/content dicts.
        system_message: System prompt for the chat model.
        max_tokens, temperature, top_p: Generation settings forwarded to chat_completion.

    Returns:
        dict with "response" (str: model reply plus the list of prognoses that share
        the query's KMeans cluster) and "image" (io.BytesIO holding the PNG t-SNE plot,
        positioned at 0).
    """
    query_embedding = model.encode([message])[0]
    query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]

    buf = _render_tsne_plot(message, query_embedding, query_cluster)

    messages = _build_chat_messages(system_message, history, message)
    response_text = _stream_reply(messages, max_tokens, temperature, top_p)

    prognosis_summary = unique_clusters.loc[
        unique_clusters['cluster'] == query_cluster, 'prognosis'
    ].tolist()
    response_text += f"\nThe query belongs to cluster {query_cluster} which includes the following prognosis: {', '.join(map(str, prognosis_summary))}."
    return {"response": response_text, "image": buf}
# Controls shown with the chat box; their values follow (message, history)
# in the callback's arguments, matching respond()'s signature.
_additional_inputs = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]


def _chat_with_plot(message, history, system_message, max_tokens, temperature, top_p):
    """ChatInterface callback returning (reply text, t-SNE plot image array)."""
    result = respond(message, history, system_message, max_tokens, temperature, top_p)
    # gr.Image cannot show a BytesIO; decode the PNG (floats in [0, 1]) into uint8 RGBA.
    image = (plt.imread(result["image"]) * 255).round().astype(np.uint8)
    return result["response"], image


def _chat_text_only(message, history, system_message, max_tokens, temperature, top_p):
    """ChatInterface callback returning only the reply text (for Gradio without additional_outputs)."""
    return respond(message, history, system_message, max_tokens, temperature, top_p)["response"]


# gr.ChatInterface has no `outputs` parameter and its callback must return the reply
# itself, not a dict. Newer Gradio supports `additional_outputs` for the plot; those
# components must be rendered by the caller, hence the surrounding Blocks.
if "additional_outputs" in inspect.signature(gr.ChatInterface.__init__).parameters:
    _plot_output = gr.Image(label="t-SNE Plot", render=False)
    with gr.Blocks() as demo:
        gr.ChatInterface(
            _chat_with_plot,
            additional_inputs=_additional_inputs,
            additional_outputs=[_plot_output],
        )
        _plot_output.render()
else:
    demo = gr.ChatInterface(_chat_text_only, additional_inputs=_additional_inputs)

if __name__ == "__main__":
    demo.launch()