Hugging Face Space (status: Sleeping)
Commit: Update app.py
File changed: app.py
@@ -6,7 +6,6 @@ from sentence_transformers import SentenceTransformer
|
|
6 |
from functools import lru_cache
|
7 |
from huggingface_hub import AsyncInferenceClient
|
8 |
from sklearn.decomposition import PCA
|
9 |
-
from sklearn.metrics import pairwise_distances_argmin_min
|
10 |
|
11 |
# ------ Data Loading ------
|
12 |
df = pd.read_csv("symbipredict_2022_filtered.csv")
|
@@ -15,40 +14,10 @@ df = pd.read_csv("symbipredict_2022_filtered.csv")
|
|
15 |
model = SentenceTransformer("all-MiniLM-L6-v2")
|
16 |
embedding_arr = model.encode(df['symptoms']).astype(np.float32)
|
17 |
|
18 |
-
# ------
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
cluster_labels = np.array([model.predict(embedding_arr) for model in kmeans_ensemble])
|
23 |
-
final_labels = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=cluster_labels)
|
24 |
-
|
25 |
-
# Cluster validation
|
26 |
-
cluster_stability = {}
|
27 |
-
for cluster_id in np.unique(final_labels):
|
28 |
-
mask = final_labels == cluster_id
|
29 |
-
stability_score = np.mean([np.sum(cluster_labels[i][mask] == cluster_id)/np.sum(mask)
|
30 |
-
for i in range(5)])
|
31 |
-
cluster_stability[cluster_id] = stability_score
|
32 |
-
|
33 |
-
cluster_prognosis_map = df.groupby(final_labels)['prognosis'].agg(lambda x: x.mode().tolist())
|
34 |
-
|
35 |
-
# ------ Session Context Tracking ------
|
36 |
-
class SessionManager:
|
37 |
-
def __init__(self):
|
38 |
-
self.sessions = {}
|
39 |
-
|
40 |
-
def get_cluster(self, session_id, query_embedding, threshold=0.85):
|
41 |
-
if session_id in self.sessions:
|
42 |
-
prev_centroid = kmeans_ensemble[0].cluster_centers_[self.sessions[session_id]]
|
43 |
-
distance = np.linalg.norm(query_embedding - prev_centroid)
|
44 |
-
if distance < threshold:
|
45 |
-
return self.sessions[session_id]
|
46 |
-
|
47 |
-
new_cluster = kmeans_ensemble[0].predict(query_embedding.reshape(1, -1))[0]
|
48 |
-
self.sessions[session_id] = new_cluster
|
49 |
-
return new_cluster
|
50 |
-
|
51 |
-
session_mgr = SessionManager()
|
52 |
|
53 |
# ------ PCA Initialization ------
|
54 |
pca = PCA(n_components=2).fit(embedding_arr)
|
@@ -64,17 +33,10 @@ client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
|
64 |
# ------ Streaming Response Function ------
|
65 |
async def respond(message, history, system_message, max_tokens, temperature, top_p):
|
66 |
try:
|
67 |
-
|
68 |
query_embedding = cached_encode(message)
|
|
|
69 |
|
70 |
-
# Get cluster with context awareness
|
71 |
-
query_cluster = session_mgr.get_cluster(session_id, query_embedding)
|
72 |
-
|
73 |
-
# Validate cluster stability
|
74 |
-
if cluster_stability[query_cluster] < 0.7:
|
75 |
-
yield "Low diagnostic confidence - please consult a healthcare professional"
|
76 |
-
return
|
77 |
-
|
78 |
# Generate streaming response
|
79 |
stream = await client.chat_completion(
|
80 |
messages=[{
|
@@ -95,7 +57,7 @@ async def respond(message, history, system_message, max_tokens, temperature, top
|
|
95 |
content = chunk.choices[0].delta.content
|
96 |
if content:
|
97 |
full_response += content
|
98 |
-
yield full_response
|
99 |
|
100 |
# Append cluster prognosis after completion
|
101 |
yield f"{full_response}\n\nCluster {query_cluster} common prognoses: {', '.join(cluster_prognosis_map[query_cluster])}"
|
|
|
6 |
from functools import lru_cache
|
7 |
from huggingface_hub import AsyncInferenceClient
|
8 |
from sklearn.decomposition import PCA
|
|
|
9 |
|
10 |
# ------ Data Loading ------
|
11 |
df = pd.read_csv("symbipredict_2022_filtered.csv")
|
|
|
14 |
model = SentenceTransformer("all-MiniLM-L6-v2")
|
15 |
embedding_arr = model.encode(df['symptoms']).astype(np.float32)
|
16 |
|
17 |
+
# ------ Clustering Setup ------
|
18 |
+
kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
|
19 |
+
cluster_labels = kmeans.fit_predict(embedding_arr)
|
20 |
+
cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(lambda x: x.mode().tolist())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# ------ PCA Initialization ------
|
23 |
pca = PCA(n_components=2).fit(embedding_arr)
|
|
|
33 |
# ------ Streaming Response Function ------
|
34 |
async def respond(message, history, system_message, max_tokens, temperature, top_p):
|
35 |
try:
|
36 |
+
# Encoding and clustering
|
37 |
query_embedding = cached_encode(message)
|
38 |
+
query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# Generate streaming response
|
41 |
stream = await client.chat_completion(
|
42 |
messages=[{
|
|
|
57 |
content = chunk.choices[0].delta.content
|
58 |
if content:
|
59 |
full_response += content
|
60 |
+
yield full_response # Stream partial responses
|
61 |
|
62 |
# Append cluster prognosis after completion
|
63 |
yield f"{full_response}\n\nCluster {query_cluster} common prognoses: {', '.join(cluster_prognosis_map[query_cluster])}"
|