AAA1988 committed on
Commit
6725c7c
·
verified ·
1 Parent(s): 70b6bd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -70
app.py CHANGED
@@ -6,99 +6,65 @@ from sentence_transformers import SentenceTransformer
6
  from functools import lru_cache
7
  from huggingface_hub import AsyncInferenceClient
8
  from sklearn.decomposition import PCA
9
- from sklearn.metrics import pairwise_distances_argmin_min
10
 
11
- # ------ Enhanced Data Loading ------
12
- df = pd.read_csv("symbipredict_2022_filtered.csv").sample(frac=1, random_state=42) # Shuffle data
13
- SYMPTOM_FIELD = 'symptoms'
14
- PROGNOSIS_FIELD = 'prognosis'
15
 
16
- # ------ Optimized Model Initialization ------
17
- model = SentenceTransformer("all-MiniLM-L6-v2", device='cpu')
18
- embedding_arr = model.encode(df[SYMPTOM_FIELD], show_progress_bar=True).astype(np.float32)
19
 
20
- # ------ Robust Clustering Setup ------
21
- kmeans = MiniBatchKMeans(
22
- n_clusters=15, # Increased for better granularity [2][13]
23
- random_state=42,
24
- n_init=5, # Multiple initializations [2][10]
25
- max_iter=300 # Better convergence [2]
26
- )
27
  cluster_labels = kmeans.fit_predict(embedding_arr)
28
- centroids = kmeans.cluster_centers_
29
 
30
- # Cluster validation metrics
31
- cluster_quality = pairwise_distances_argmin_min(embedding_arr, centroids)[1].mean()
32
 
33
- # Prognosis mapping with confidence scores
34
- cluster_prognosis_map = df.groupby(cluster_labels)[PROGNOSIS_FIELD].agg(
35
- lambda x: x.value_counts(normalize=True).head(3).to_dict() # Top 3 prognoses with frequencies
36
- )
37
 
38
- # ------ Session Context Management ------
39
- class DiagnosisSession:
40
- def __init__(self):
41
- self.sessions = {}
42
- self.similarity_threshold = 0.82 # Optimized per [11]
43
-
44
- def get_cluster(self, history, query_embedding):
45
- session_id = hash(frozenset(history))
46
- if session_id in self.sessions:
47
- prev_centroid = centroids[self.sessions[session_id]['cluster']]
48
- distance = np.linalg.norm(query_embedding - prev_centroid)
49
- if distance < self.similarity_threshold:
50
- return self.sessions[session_id]
51
-
52
- new_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]
53
- self.sessions[session_id] = {
54
- 'cluster': new_cluster,
55
- 'embedding': query_embedding
56
- }
57
- return self.sessions[session_id]
58
-
59
- session_manager = DiagnosisSession()
60
 
61
- # ------ Streaming Response Improvements ------
62
  async def respond(message, history, system_message, max_tokens, temperature, top_p):
63
  try:
 
64
  query_embedding = cached_encode(message)
65
- session = session_manager.get_cluster(history, query_embedding)
66
- cluster_info = cluster_prognosis_map[session['cluster']]
67
 
68
- # Validate cluster quality
69
- if cluster_quality < 0.65: # [10][13]
70
- yield "System confidence low - consult a healthcare professional"
71
- return
72
-
73
  # Generate streaming response
74
  stream = await client.chat_completion(
75
  messages=[{
76
  "role": "system",
77
- f"content": f"{system_message}\nCurrent cluster: {session['cluster']}"
78
  }, {
79
  "role": "user",
80
  "content": message
81
  }],
82
- max_tokens=min(max_tokens, 1024), # Safety limit
83
  stream=True,
84
- temperature=max(0.1, min(temperature, 1.0)), # Constrained randomness
85
  top_p=top_p
86
  )
87
-
88
  full_response = ""
89
  async for chunk in stream:
90
- if chunk.choices[0].delta.content:
91
- full_response += chunk.choices[0].delta.content
92
- yield full_response
93
-
94
- # Format prognosis display
95
- top_diagnoses = [f"{k} ({v:.1%})" for k,v in cluster_info.items()]
96
- yield f"{full_response}\n\nLikely conditions: {', '.join(top_diagnoses)}"
97
-
98
  except Exception as e:
99
- yield f"⚠️ Medical system error: {str(e)}"
100
 
101
- # ------ Enhanced Gradio Interface ------
102
  # ------ Gradio Interface ------
103
  demo = gr.ChatInterface(
104
  respond,
@@ -108,9 +74,7 @@ demo = gr.ChatInterface(
108
  gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
109
  gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
110
  ]
111
- ).queue(concurrency_limit=5) # Updated concurrency handling
112
 
113
  if __name__ == "__main__":
114
- demo.launch(max_threads=10) # Add thread pool configuration
115
-
116
-
 
6
  from functools import lru_cache
7
  from huggingface_hub import AsyncInferenceClient
8
  from sklearn.decomposition import PCA
 
9
 
10
# ------ Data Loading ------
# Symptom/prognosis dataset; the 'symptoms' and 'prognosis' columns are
# consumed by the embedding and clustering steps below.
df = pd.read_csv("symbipredict_2022_filtered.csv")
 
 
12
 
13
# ------ Model Initialization ------
# Compact sentence-embedding model; casting to float32 keeps the embedding
# matrix small without affecting downstream k-means / PCA results.
model = SentenceTransformer("all-MiniLM-L6-v2")
embedding_arr = model.encode(df['symptoms']).astype(np.float32)
16
 
17
# ------ Clustering Setup ------
# Fixed random_state makes cluster assignments reproducible across restarts.
kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
cluster_labels = kmeans.fit_predict(embedding_arr)
# For each cluster id, record the most frequent prognosis value(s) among the
# rows assigned to it (mode() may return several ties, hence the list).
cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(lambda x: x.mode().tolist())
21
 
22
# ------ PCA Initialization ------
# NOTE(review): 2-D projection of the symptom embeddings. `pca` is not
# referenced anywhere else in this file — presumably intended for a
# visualization; confirm before removing.
pca = PCA(n_components=2).fit(embedding_arr)
24
 
25
# ------ Cached Functions ------
@lru_cache(maxsize=100)
def cached_encode(text):
    """Return the sentence embedding for *text*, memoized per unique string.

    maxsize=100 bounds the cache so arbitrary user input cannot grow it
    without limit; model.encode is the expensive call being avoided.
    """
    return model.encode(text, convert_to_numpy=True)
29
 
30
# ------ Async Inference Client ------
# Hosted Zephyr-7B chat model; used for streamed completions in respond().
client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
# ------ Streaming Response Function ------
async def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream an LLM reply for *message*, then append cluster prognoses.

    Yields progressively longer partial responses (Gradio streaming
    protocol); the final yield is the complete reply plus the common
    prognoses of the k-means cluster the query embedding falls into.
    *history* is required by the gr.ChatInterface signature but unused here.
    """
    try:
        # Embed the user message (memoized) and locate its k-means cluster.
        query_embedding = cached_encode(message)
        query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]

        # Start a streamed chat completion with the caller-supplied sampling knobs.
        chat_messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": message},
        ]
        stream = await client.chat_completion(
            messages=chat_messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )

        full_response = ""
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta:
                full_response += delta
                yield full_response  # each yield replaces the previous partial text

        # Final yield: complete reply plus this cluster's common prognoses.
        yield f"{full_response}\n\nCluster {query_cluster} common prognoses: {', '.join(cluster_prognosis_map[query_cluster])}"

    except Exception as e:
        # Surface any failure (encoding, inference, lookup) to the chat UI.
        yield f"Error: {str(e)}"
67
 
 
68
  # ------ Gradio Interface ------
69
  demo = gr.ChatInterface(
70
  respond,
 
74
  gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
75
  gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
76
  ]
77
+ ).queue()
78
 
79
if __name__ == "__main__":
    # Launch the Gradio app only when executed as a script (not on import).
    demo.launch()