AAA1988 commited on
Commit
73a69a7
·
verified ·
1 Parent(s): cdfae42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -30
app.py CHANGED
@@ -17,7 +17,9 @@ embedding_arr = model.encode(df['symptoms']).astype(np.float32)
17
  # ------ Clustering Setup ------
18
  kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
19
  cluster_labels = kmeans.fit_predict(embedding_arr)
20
- cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(lambda x: x.mode().tolist())
 
 
21
 
22
  # ------ PCA Initialization ------
23
  pca = PCA(n_components=2).fit(embedding_arr)
@@ -31,7 +33,7 @@ def cached_encode(text):
31
  client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
32
 
33
  # ------ Streaming Response Function ------
34
- async def respond(message, history, system_message, max_tokens, temperature, top_p):
35
  try:
36
  # Encoding and clustering
37
  query_embedding = cached_encode(message)
@@ -46,10 +48,10 @@ async def respond(message, history, system_message, max_tokens, temperature, top
46
  "role": "user",
47
  "content": message
48
  }],
49
- max_tokens=max_tokens,
50
  stream=True,
51
- temperature=temperature,
52
- top_p=top_p
53
  )
54
 
55
  full_response = ""
@@ -57,36 +59,25 @@ async def respond(message, history, system_message, max_tokens, temperature, top
57
  content = chunk.choices[0].delta.content
58
  if content:
59
  full_response += content
60
- yield full_response # Stream partial responses
61
 
62
- # Append cluster prognosis after completion
63
- yield f"{full_response}\n\nCluster {query_cluster} common prognoses: {', '.join(cluster_prognosis_map[query_cluster])}"
 
 
 
 
64
 
65
  except Exception as e:
66
  yield f"Error: {str(e)}"
67
 
68
- # ------ Gradio Interface with Example Query Button ------
69
- with gr.Blocks() as demo:
70
- with gr.Row():
71
- chatbot = gr.ChatInterface(
72
- respond,
73
- additional_inputs=[
74
- gr.Textbox(value="Medical diagnosis assistant", label="System Role"),
75
- gr.Slider(512, 2048, value=512, step=128, label="Max Tokens"),
76
- gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
77
- gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
78
- ]
79
- )
80
- with gr.Row():
81
- example_button = gr.Button("Use Example Query")
82
- example_query = gr.Textbox(
83
- value="I have a headache and fever",
84
- label="Example Query",
85
- interactive=False
86
- )
87
-
88
- # Trigger example query manually when button is clicked
89
- example_button.click(fn=lambda: "I have a headache and fever", outputs=[chatbot])
90
 
91
  if __name__ == "__main__":
92
  demo.launch()
 
17
  # ------ Clustering Setup ------
18
  kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
19
  cluster_labels = kmeans.fit_predict(embedding_arr)
20
+ cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(
21
+ lambda x: x.value_counts(normalize=True).head(3).to_dict()
22
+ )
23
 
24
  # ------ PCA Initialization ------
25
  pca = PCA(n_components=2).fit(embedding_arr)
 
33
  client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
34
 
35
  # ------ Streaming Response Function ------
36
+ async def respond(message, history, system_message):
37
  try:
38
  # Encoding and clustering
39
  query_embedding = cached_encode(message)
 
48
  "role": "user",
49
  "content": message
50
  }],
51
+ max_tokens=512,
52
  stream=True,
53
+ temperature=0.7,
54
+ top_p=0.95
55
  )
56
 
57
  full_response = ""
 
59
  content = chunk.choices[0].delta.content
60
  if content:
61
  full_response += content
 
62
 
63
+ # Format prognosis likelihoods
64
+ prognosis_info = cluster_prognosis_map[query_cluster]
65
+ formatted_prognoses = ", ".join([f"{k}: {v*100:.1f}%" for k, v in prognosis_info.items()])
66
+
67
+ # Append formatted prognosis to response
68
+ yield f"{full_response}\n\nCommon prognoses: {formatted_prognoses}"
69
 
70
  except Exception as e:
71
  yield f"Error: {str(e)}"
72
 
73
+ # ------ Gradio Interface ------
74
+ demo = gr.ChatInterface(
75
+ respond,
76
+ additional_inputs=[
77
+ gr.Textbox(value="Medical diagnosis assistant", label="System Role")
78
+ ],
79
+ examples=[["I have a headache and fever"]] # Example query for guidance
80
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  if __name__ == "__main__":
83
  demo.launch()