AAA1988 committed on
Commit
c03c66e
·
verified ·
1 Parent(s): 0c88f3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -27
app.py CHANGED
@@ -1,32 +1,97 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
  from huggingface_hub import InferenceClient
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
27
 
28
- response = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
 
30
  for message in client.chat_completion(
31
  messages,
32
  max_tokens=max_tokens,
@@ -35,13 +100,14 @@ def respond(
35
  top_p=top_p,
36
  ):
37
  token = message.choices[0].delta.content
 
38
 
39
- response += token
40
- yield response
41
 
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
  demo = gr.ChatInterface(
46
  respond,
47
  additional_inputs=[
@@ -56,8 +122,8 @@ demo = gr.ChatInterface(
56
  label="Top-p (nucleus sampling)",
57
  ),
58
  ],
 
59
  )
60
 
61
-
62
  if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ from sklearn.manifold import TSNE
6
+ from sklearn.cluster import KMeans
7
+ from sentence_transformers import SentenceTransformer
8
+ import io
9
  from huggingface_hub import InferenceClient
10
 
11
# Load the dataset of symptom descriptions and their prognoses.
# Runs at import time so embeddings/clusters are built once per process.
file_path = 'symbipredict_2022_filtered.csv' # Ensure this file is uploaded to the Space
df = pd.read_csv(file_path)

# Load the sentence-embedding model from the local directory
model_path = "all-MiniLM-L6-v2" # Ensure this directory is uploaded to the Space
model = SentenceTransformer(model_path)

# Embed vectors: one dense vector per row of df['symptoms'].
embedding_arr = model.encode(df['symptoms'])

# Apply K-Means with the optimal number of clusters (41 clusters)
# NOTE(review): 41 was presumably chosen offline (elbow/silhouette) — confirm.
optimal_n_clusters = 41
kmeans = KMeans(n_clusters=optimal_n_clusters, random_state=42)
kmeans_labels = kmeans.fit_predict(embedding_arr)

# Create a DataFrame with prognosis and their corresponding clusters
cluster_prognosis_mapping = pd.DataFrame({'prognosis': df['prognosis'], 'cluster': kmeans_labels})

# Get the unique cluster-prognosis pairs (read later by respond() to
# summarize which prognoses share the query's cluster).
unique_clusters = cluster_prognosis_mapping.drop_duplicates().sort_values(by='cluster')

# Initialize the Inference Client
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
35
 
36
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer one chat turn.

    Embeds the user's symptom query, projects it into the existing t-SNE
    space alongside the dataset embeddings, renders a cluster plot, then
    streams an LLM reply and appends the prognoses sharing the query's
    K-Means cluster.

    Args:
        message: the user's symptom description (str).
        history: list of (user_msg, bot_msg) tuples from prior turns.
        system_message: system prompt for the LLM.
        max_tokens, temperature, top_p: generation parameters.

    Returns:
        dict with "response" (str) and "image" (io.BytesIO holding a PNG).
    """
    query_embedding = model.encode([message])[0]

    # Combine embeddings with the query embedding so t-SNE places the
    # query in the same 2-D space as the dataset.
    combined_embeddings = np.vstack([embedding_arr, query_embedding])

    # t-SNE requires perplexity < n_samples; clamp for small datasets.
    perplexity = min(30, len(combined_embeddings) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, n_iter=1000, random_state=42)
    embedding_tsne = tsne.fit_transform(combined_embeddings)

    # Separate the transformed query embedding from the rest.
    embedding_tsne_query = embedding_tsne[-1]
    embedding_tsne = embedding_tsne[:-1]

    # Plot data along t-SNE components with the query highlighted.
    plt.figure(figsize=(14, 10))
    plt.rcParams.update({'font.size': 16})
    plt.grid()

    # Use a colormap for different clusters.
    cmap = plt.get_cmap('tab20', optimal_n_clusters)

    # Cluster to which the query embedding belongs.
    highlight_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]

    prognoses = df['prognosis'].unique()
    for i, prognosis in enumerate(prognoses):
        # Flat integer indices (np.where returns a tuple; [0] avoids
        # deprecated nested-tuple indexing of embedding_tsne below).
        idx = np.where(df['prognosis'] == prognosis)[0]
        color = [cmap(i / len(prognoses))] * len(idx)
        if kmeans.predict(embedding_arr[idx])[0] == highlight_cluster:
            plt.scatter(embedding_tsne[idx, 0], embedding_tsne[idx, 1], c=color,
                        edgecolor='black', linewidth=1,
                        label=f'{prognosis} (Cluster {highlight_cluster})')
        else:
            plt.scatter(embedding_tsne[idx, 0], embedding_tsne[idx, 1], c=color,
                        label=prognosis)

    plt.scatter(embedding_tsne_query[0], embedding_tsne_query[1], c='k', marker='D', s=200, label='query')

    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")
    plt.title(f'Query: "{message}" (Belongs to Cluster {highlight_cluster})')

    # Save the plot to a bytes buffer and release the figure.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close()

    # Build the chat transcript for the Inference Client.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    response_text = ""
    # Loop variable renamed from `message` — the original shadowed the
    # user-query parameter used above.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # final stream chunk may carry None content
            response_text += token

    prognosis_summary = unique_clusters[unique_clusters['cluster'] == highlight_cluster]['prognosis'].tolist()
    response_text += f"\nThe query belongs to cluster {highlight_cluster} which includes the following prognosis: {', '.join(prognosis_summary)}."

    return {"response": response_text, "image": buf}
109
+
110
+ # Set up the Gradio Chat Interface
111
  demo = gr.ChatInterface(
112
  respond,
113
  additional_inputs=[
 
122
  label="Top-p (nucleus sampling)",
123
  ),
124
  ],
125
+ outputs=[gr.Textbox(label="Prognosis Summary"), gr.Image(label="t-SNE Plot")]
126
  )
127
 
 
128
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()