# Hugging Face Space (status when captured: Running)
import io
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
# ---------------------------------------------------------------------------
# One-time start-up work: load data, embed it, cluster it, connect the LLM.
# ---------------------------------------------------------------------------

# Symptom/prognosis table; the CSV must be uploaded alongside this app.
file_path = 'symbipredict_2022_filtered.csv'
df = pd.read_csv(file_path)

# Sentence-embedding model, loaded from a directory shipped with the Space.
model_path = "all-MiniLM-L6-v2"
model = SentenceTransformer(model_path)

# One embedding row per row of `df`, in the same order.
embedding_arr = model.encode(df['symptoms'])

# K-Means over the symptom embeddings; 41 is the cluster count chosen as
# optimal for this dataset.
optimal_n_clusters = 41
kmeans = KMeans(n_clusters=optimal_n_clusters, random_state=42)
kmeans_labels = kmeans.fit_predict(embedding_arr)

# Attach each row's cluster id to its prognosis, then keep only the distinct
# (prognosis, cluster) pairs, ordered by cluster id.
cluster_prognosis_mapping = df[['prognosis']].assign(cluster=kmeans_labels)
unique_clusters = (
    cluster_prognosis_mapping
    .drop_duplicates()
    .sort_values(by='cluster')
)

# Hosted chat model that writes the text part of each reply.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def _history_to_messages(system_message, history, message):
    """Convert Gradio chat history into a chat-completion message list.

    Accepts both history formats ``gr.ChatInterface`` may pass:

    * "tuples": ``[[user_text, assistant_text], ...]``; falsy entries are skipped.
    * "messages": ``[{"role": ..., "content": ...}, ...]``; only user/assistant
      turns whose content is a non-empty string are kept (file or component
      contents are dropped).

    The system prompt is placed first and the new user ``message`` last.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    return messages


def _render_tsne_plot(message, query_embedding, query_cluster):
    """Plot the corpus plus the query in 2-D t-SNE space; return PNG bytes.

    Returns an ``io.BytesIO`` holding the PNG, positioned at offset 0. Every
    prognosis with at least one sample in ``query_cluster`` is drawn with a
    black edge and labelled with that cluster id, matching the prognosis list
    that ``respond`` reports for the cluster.
    """
    # An explicit Figure is independent of pyplot's global "current figure",
    # so concurrent Gradio requests cannot draw onto each other's plot.
    from matplotlib.figure import Figure

    # t-SNE cannot project new points into an existing embedding, so the
    # query is fitted together with the whole corpus (on every call) and
    # split off afterwards.
    combined_embeddings = np.vstack([embedding_arr, query_embedding])
    n_samples = combined_embeddings.shape[0]
    # scikit-learn requires perplexity < n_samples. The iteration count is left
    # at its default (1000): `n_iter` was renamed `max_iter` in scikit-learn 1.5
    # and later removed, so spelling out either name breaks on some versions.
    tsne = TSNE(n_components=2, perplexity=min(30, n_samples - 1), random_state=42)
    projected = tsne.fit_transform(combined_embeddings)
    query_point, corpus_points = projected[-1], projected[:-1]

    # Same global font setting as before (idempotent across calls).
    plt.rcParams.update({'font.size': 16})
    fig = Figure(figsize=(14, 10))
    ax = fig.add_subplot()
    ax.grid()

    cmap = plt.get_cmap('tab20', optimal_n_clusters)
    prognoses = df['prognosis'].unique()
    for i, prognosis in enumerate(prognoses):
        mask = (df['prognosis'] == prognosis).to_numpy()
        colour = cmap(i / len(prognoses))
        xs, ys = corpus_points[mask, 0], corpus_points[mask, 1]
        # kmeans_labels are the fitted cluster ids of the rows of `df`.
        if np.any(kmeans_labels[mask] == query_cluster):
            ax.scatter(xs, ys, color=colour, edgecolor='black', linewidth=1,
                       label=f'{prognosis} (Cluster {query_cluster})')
        else:
            ax.scatter(xs, ys, color=colour, label=prognosis)
    ax.scatter(query_point[0], query_point[1], c='k', marker='D', s=200, label='query')

    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_xlabel("t-SNE Component 1")
    ax.set_ylabel("t-SNE Component 2")
    # Escape '$' so user text is rendered literally instead of as mathtext,
    # which can fail when the figure is saved.
    title_text = message.replace('$', r'\$')
    ax.set_title(f'Query: "{title_text}" (Belongs to Cluster {query_cluster})')

    buf = io.BytesIO()
    # The legend sits outside the axes; a tight bbox keeps it in the image.
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return buf


def _stream_reply(messages, max_tokens, temperature, top_p):
    """Stream a chat completion from the module-level ``client``; return the full text."""
    parts = []
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Guard against chunks with no choices or with a None delta. Appending
        # None to a str would raise TypeError.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            parts.append(token)
    return "".join(parts)


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer a chat message and locate it among the clustered symptom embeddings.

    Args:
        message: The new user message.
        history: Earlier chat turns from Gradio, in "tuples" or "messages" format.
        system_message: System prompt for the chat model.
        max_tokens, temperature, top_p: Generation settings forwarded to
            ``client.chat_completion``.

    Returns:
        A dict with two keys:

        * ``"response"``: the model's reply, followed by a sentence naming the
          query's K-Means cluster and the prognoses found in it.
        * ``"image"``: an ``io.BytesIO`` containing a PNG t-SNE plot,
          positioned at offset 0.
    """
    query_embedding = model.encode([message])[0]
    query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
    buf = _render_tsne_plot(message, query_embedding, query_cluster)

    messages = _history_to_messages(system_message, history, message)
    response_text = _stream_reply(messages, max_tokens, temperature, top_p)

    prognosis_summary = unique_clusters.loc[
        unique_clusters['cluster'] == query_cluster, 'prognosis'
    ].tolist()
    response_text += f"\nThe query belongs to cluster {query_cluster} which includes the following prognosis: {', '.join(prognosis_summary)}."
    return {"response": response_text, "image": buf}
def _respond_with_plot(message, history, system_message, max_tokens, temperature, top_p):
    """Adapt ``respond`` to the return contract of ``gr.ChatInterface``.

    ``respond`` returns ``{"response": str, "image": io.BytesIO}``. That dict is
    not a valid chat reply. With ``additional_outputs``, the chat function must
    return the reply text followed by one value per extra output component.
    ``gr.Image`` accepts a file path, so the PNG bytes are written to a
    temporary file.
    NOTE(review): files are created with delete=False and are never removed here.
    """
    result = respond(message, history, system_message, max_tokens, temperature, top_p)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp.write(result["image"].getvalue())
    return result["response"], tmp.name


# gr.ChatInterface has no `outputs=` parameter (the original call raised
# TypeError). The plot is routed through `additional_outputs`, which requires a
# Gradio release that supports it (TODO confirm against the Space's pinned
# sdk_version). Those components are not rendered automatically, so the image
# is created unrendered and placed below the chat. The prognosis summary is
# already appended to the chat reply itself, so no separate textbox is needed.
with gr.Blocks() as demo:
    tsne_plot = gr.Image(label="t-SNE Plot", render=False)
    gr.ChatInterface(
        _respond_with_plot,
        additional_inputs=[
            gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
        additional_outputs=[tsne_plot],
    )
    tsne_plot.render()

if __name__ == "__main__":
    demo.launch()