AAA1988 commited on
Commit
6bb7a87
·
verified ·
1 Parent(s): a2c3301

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -45
app.py CHANGED
@@ -6,7 +6,6 @@ from sentence_transformers import SentenceTransformer
6
  from functools import lru_cache
7
  from huggingface_hub import AsyncInferenceClient
8
  from sklearn.decomposition import PCA
9
- from sklearn.metrics import pairwise_distances_argmin_min
10
 
11
  # ------ Data Loading ------
12
  df = pd.read_csv("symbipredict_2022_filtered.csv")
@@ -15,40 +14,10 @@ df = pd.read_csv("symbipredict_2022_filtered.csv")
15
  model = SentenceTransformer("all-MiniLM-L6-v2")
16
  embedding_arr = model.encode(df['symptoms']).astype(np.float32)
17
 
18
- # ------ Enhanced Clustering Setup ------
19
- # Ensemble clustering with multiple initializations
20
- kmeans_ensemble = [MiniBatchKMeans(n_clusters=10, random_state=i).fit(embedding_arr)
21
- for i in range(5)]
22
- cluster_labels = np.array([model.predict(embedding_arr) for model in kmeans_ensemble])
23
- final_labels = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=cluster_labels)
24
-
25
- # Cluster validation
26
- cluster_stability = {}
27
- for cluster_id in np.unique(final_labels):
28
- mask = final_labels == cluster_id
29
- stability_score = np.mean([np.sum(cluster_labels[i][mask] == cluster_id)/np.sum(mask)
30
- for i in range(5)])
31
- cluster_stability[cluster_id] = stability_score
32
-
33
- cluster_prognosis_map = df.groupby(final_labels)['prognosis'].agg(lambda x: x.mode().tolist())
34
-
35
- # ------ Session Context Tracking ------
36
- class SessionManager:
37
- def __init__(self):
38
- self.sessions = {}
39
-
40
- def get_cluster(self, session_id, query_embedding, threshold=0.85):
41
- if session_id in self.sessions:
42
- prev_centroid = kmeans_ensemble[0].cluster_centers_[self.sessions[session_id]]
43
- distance = np.linalg.norm(query_embedding - prev_centroid)
44
- if distance < threshold:
45
- return self.sessions[session_id]
46
-
47
- new_cluster = kmeans_ensemble[0].predict(query_embedding.reshape(1, -1))[0]
48
- self.sessions[session_id] = new_cluster
49
- return new_cluster
50
-
51
- session_mgr = SessionManager()
52
 
53
  # ------ PCA Initialization ------
54
  pca = PCA(n_components=2).fit(embedding_arr)
@@ -64,17 +33,10 @@ client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
64
  # ------ Streaming Response Function ------
65
  async def respond(message, history, system_message, max_tokens, temperature, top_p):
66
  try:
67
- session_id = hash(frozenset(history))
68
  query_embedding = cached_encode(message)
 
69
 
70
- # Get cluster with context awareness
71
- query_cluster = session_mgr.get_cluster(session_id, query_embedding)
72
-
73
- # Validate cluster stability
74
- if cluster_stability[query_cluster] < 0.7:
75
- yield "Low diagnostic confidence - please consult a healthcare professional"
76
- return
77
-
78
  # Generate streaming response
79
  stream = await client.chat_completion(
80
  messages=[{
@@ -95,7 +57,7 @@ async def respond(message, history, system_message, max_tokens, temperature, top
95
  content = chunk.choices[0].delta.content
96
  if content:
97
  full_response += content
98
- yield full_response
99
 
100
  # Append cluster prognosis after completion
101
  yield f"{full_response}\n\nCluster {query_cluster} common prognoses: {', '.join(cluster_prognosis_map[query_cluster])}"
 
6
  from functools import lru_cache
7
  from huggingface_hub import AsyncInferenceClient
8
  from sklearn.decomposition import PCA
 
9
 
10
  # ------ Data Loading ------
11
  df = pd.read_csv("symbipredict_2022_filtered.csv")
 
14
  model = SentenceTransformer("all-MiniLM-L6-v2")
15
  embedding_arr = model.encode(df['symptoms']).astype(np.float32)
16
 
17
+ # ------ Clustering Setup ------
18
+ kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
19
+ cluster_labels = kmeans.fit_predict(embedding_arr)
20
+ cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(lambda x: x.mode().tolist())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # ------ PCA Initialization ------
23
  pca = PCA(n_components=2).fit(embedding_arr)
 
33
  # ------ Streaming Response Function ------
34
  async def respond(message, history, system_message, max_tokens, temperature, top_p):
35
  try:
36
+ # Encoding and clustering
37
  query_embedding = cached_encode(message)
38
+ query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
39
 
 
 
 
 
 
 
 
 
40
  # Generate streaming response
41
  stream = await client.chat_completion(
42
  messages=[{
 
57
  content = chunk.choices[0].delta.content
58
  if content:
59
  full_response += content
60
+ yield full_response # Stream partial responses
61
 
62
  # Append cluster prognosis after completion
63
  yield f"{full_response}\n\nCluster {query_cluster} common prognoses: {', '.join(cluster_prognosis_map[query_cluster])}"