AAA1988 committed on
Commit
cd072df
·
verified ·
1 Parent(s): 619550d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -95
app.py CHANGED
@@ -1,126 +1,115 @@
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
 
 
4
  import matplotlib.pyplot as plt
5
  from sklearn.manifold import TSNE
6
- from sklearn.cluster import KMeans
7
  from sentence_transformers import SentenceTransformer
 
 
 
8
  import io
9
  from huggingface_hub import InferenceClient
10
 
11
- # Load the dataset
12
- file_path = 'symbipredict_2022_filtered.csv' # Ensure this file is uploaded to the Space
13
  df = pd.read_csv(file_path)
 
 
14
 
15
- # Load the model from the local directory
16
- model_path = "all-MiniLM-L6-v2" # Ensure this directory is uploaded to the Space
17
- model = SentenceTransformer(model_path)
18
-
19
- # Embed vectors
20
- embedding_arr = model.encode(df['symptoms'])
21
-
22
- # Apply K-Means with the optimal number of clusters (41 clusters)
23
  optimal_n_clusters = 41
24
- kmeans = KMeans(n_clusters=optimal_n_clusters, random_state=42)
 
 
25
  kmeans_labels = kmeans.fit_predict(embedding_arr)
26
 
27
- # Create a DataFrame with prognosis and their corresponding clusters
28
- cluster_prognosis_mapping = pd.DataFrame({'prognosis': df['prognosis'], 'cluster': kmeans_labels})
 
29
 
30
- # Get the unique cluster-prognosis pairs
31
- unique_clusters = cluster_prognosis_mapping.drop_duplicates().sort_values(by='cluster')
 
32
 
33
- # Initialize the Inference Client
34
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
35
 
36
- def respond(message, history, system_message, max_tokens, temperature, top_p):
37
- query_embedding = model.encode([message])[0]
38
-
39
- # Combine embeddings with the query embedding
40
- combined_embeddings = np.vstack([embedding_arr, query_embedding])
41
-
42
- # Apply t-SNE to the combined embeddings
43
- tsne = TSNE(n_components=2, perplexity=30, n_iter=1000, random_state=42)
44
- embedding_tsne = tsne.fit_transform(combined_embeddings)
45
-
46
- # Separate the transformed query embedding from the rest
47
- embedding_tsne_query = embedding_tsne[-1]
48
- embedding_tsne = embedding_tsne[:-1]
49
-
50
- # Plot data along t-SNE components with the query
51
- plt.figure(figsize=(14, 10))
52
- plt.rcParams.update({'font.size': 16})
53
- plt.grid()
54
-
55
- # Use a colormap for different clusters
56
- cmap = plt.get_cmap('tab20', optimal_n_clusters)
57
-
58
- # Highlight the cluster to which the query embedding belongs
59
- query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
60
- highlight_cluster = query_cluster
61
-
62
- c = 0
63
- for prognosis in df['prognosis'].unique():
64
- idx = np.where(df['prognosis'] == prognosis)
65
- if kmeans.predict(embedding_arr[idx])[0] == highlight_cluster:
66
- plt.scatter(embedding_tsne[idx, 0], embedding_tsne[idx, 1], c=[cmap(c)] * len(idx[0]), edgecolor='black', linewidth=1, label=f'{prognosis} (Cluster {highlight_cluster})')
67
- else:
68
- plt.scatter(embedding_tsne[idx, 0], embedding_tsne[idx, 1], c=[cmap(c)] * len(idx[0]), label=prognosis)
69
- c = c + 1 / len(df['prognosis'].unique())
70
 
71
- plt.scatter(embedding_tsne_query[0], embedding_tsne_query[1], c='k', marker='D', s=200, label='query')
 
 
 
72
 
73
- plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
74
- plt.xticks(rotation=45)
75
- plt.xlabel("t-SNE Component 1")
76
- plt.ylabel("t-SNE Component 2")
77
- plt.title(f'Query: "{message}" (Belongs to Cluster {highlight_cluster})')
78
 
79
- # Save the plot to a bytes buffer
 
 
80
  buf = io.BytesIO()
81
- plt.savefig(buf, format='png')
82
  buf.seek(0)
83
- plt.close()
84
-
85
- # Generate the text response using the Inference Client
86
- messages = [{"role": "system", "content": system_message}]
87
- for user_msg, bot_msg in history:
88
- if user_msg:
89
- messages.append({"role": "user", "content": user_msg})
90
- if bot_msg:
91
- messages.append({"role": "assistant", "content": bot_msg})
92
- messages.append({"role": "user", "content": message})
93
-
94
- response_text = ""
95
- for message in client.chat_completion(
96
- messages,
97
- max_tokens=max_tokens,
98
- stream=True,
99
- temperature=temperature,
100
- top_p=top_p,
101
- ):
102
- token = message.choices[0].delta.content
103
- response_text += token
104
 
105
- prognosis_summary = unique_clusters[unique_clusters['cluster'] == highlight_cluster]['prognosis'].tolist()
106
- response_text += f"\nThe query belongs to cluster {highlight_cluster} which includes the following prognosis: {', '.join(prognosis_summary)}."
107
 
108
- return response_text, buf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- # Set up the Gradio Chat Interface
111
  demo = gr.ChatInterface(
112
  respond,
113
  additional_inputs=[
114
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
115
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
116
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
117
- gr.Slider(
118
- minimum=0.1,
119
- maximum=1.0,
120
- value=0.95,
121
- step=0.05,
122
- label="Top-p (nucleus sampling)",
123
- ),
124
  ]
125
  )
126
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
  import matplotlib.pyplot as plt
7
  from sklearn.manifold import TSNE
8
+ from sklearn.cluster import MiniBatchKMeans
9
  from sentence_transformers import SentenceTransformer
10
+ from umap import UMAP
11
+ from joblib import Parallel, delayed
12
+ from functools import lru_cache
13
  import io
14
  from huggingface_hub import InferenceClient
15
 
16
+ # ---- Precomputed Elements ----
17
+ file_path = 'symbipredict_2022_filtered.csv'
18
  df = pd.read_csv(file_path)
19
+ model = SentenceTransformer("all-MiniLM-L6-v2")
20
+ embedding_arr = model.encode(df['symptoms']).astype(np.float32)
21
 
22
+ # Clustering
 
 
 
 
 
 
 
23
  optimal_n_clusters = 41
24
+ kmeans = MiniBatchKMeans(n_clusters=optimal_n_clusters,
25
+ batch_size=1024,
26
+ random_state=42)
27
  kmeans_labels = kmeans.fit_predict(embedding_arr)
28
 
29
+ # Dimensionality Reduction
30
+ umap = UMAP(n_components=2, random_state=42)
31
+ embedding_umap = umap.fit_transform(embedding_arr)
32
 
33
+ # Precomputed Mappings
34
+ cluster_prognosis_map = (df.groupby('cluster')['prognosis']
35
+ .unique().to_dict())
36
 
37
+ # ---- Cached Functions ----
38
+ @lru_cache(maxsize=100)
39
+ def cached_encode(text):
40
+ return model.encode([text], convert_to_numpy=True)[0]
41
 
42
+ # ---- Optimized Plotting ----
43
+ fig = plt.figure(figsize=(14, 10))
44
+ ax = fig.add_subplot(111)
45
+ cmap = plt.get_cmap('tab20', optimal_n_clusters)
46
+
47
+ def create_plot(message, query_embedding):
48
+ ax.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ # Plot all points
51
+ ax.scatter(embedding_umap[:, 0], embedding_umap[:, 1],
52
+ c=kmeans_labels, cmap=cmap,
53
+ edgecolor='k', linewidth=0.5, alpha=0.7)
54
 
55
+ # Plot query
56
+ query_umap = umap.transform(query_embedding.reshape(1, -1))
57
+ ax.scatter(query_umap[:, 0], query_umap[:, 1],
58
+ c='red', marker='X', s=200,
59
+ label=f'Query: {message}')
60
 
61
+ # Finalize plot
62
+ ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
63
+ ax.set_title("Medical Condition Clustering")
64
  buf = io.BytesIO()
65
+ plt.savefig(buf, format='png', bbox_inches='tight')
66
  buf.seek(0)
67
+ return buf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # ---- Response Function ----
70
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
71
 
72
+ async def respond(message, history, system_message, max_tokens, temperature, top_p):
73
+ try:
74
+ # Encoding
75
+ query_embedding = cached_encode(message)
76
+
77
+ # Cluster prediction
78
+ query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
79
+
80
+ # Parallel plot generation
81
+ plot_buf = await asyncio.to_thread(
82
+ create_plot, message, query_embedding
83
+ )
84
+
85
+ # Async LLM response
86
+ llm_response = await client.chat_completion(
87
+ [{"role": "user", "content": message}],
88
+ max_tokens=max_tokens,
89
+ stream=False,
90
+ temperature=temperature,
91
+ top_p=top_p
92
+ )
93
+
94
+ # Combine responses
95
+ full_response = (
96
+ f"{llm_response}\n\nCluster {query_cluster} contains: "
97
+ f"{', '.join(cluster_prognosis_map[query_cluster])}"
98
+ )
99
+
100
+ return full_response, plot_buf
101
+
102
+ except Exception as e:
103
+ return f"Error: {str(e)}", None
104
 
105
+ # ---- Gradio Interface ----
106
  demo = gr.ChatInterface(
107
  respond,
108
  additional_inputs=[
109
+ gr.Textbox(value="Medical diagnosis assistant", label="System Role"),
110
+ gr.Slider(512, 2048, value=512, step=128, label="Max Tokens"),
111
+ gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
112
+ gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
 
 
 
 
 
 
113
  ]
114
  )
115