Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -17,7 +17,9 @@ embedding_arr = model.encode(df['symptoms']).astype(np.float32)
|
|
17 |
# ------ Clustering Setup ------
|
18 |
kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
|
19 |
cluster_labels = kmeans.fit_predict(embedding_arr)
|
20 |
-
cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(
|
|
|
|
|
21 |
|
22 |
# ------ PCA Initialization ------
|
23 |
pca = PCA(n_components=2).fit(embedding_arr)
|
@@ -31,7 +33,7 @@ def cached_encode(text):
|
|
31 |
client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
32 |
|
33 |
# ------ Streaming Response Function ------
|
34 |
-
async def respond(message, history, system_message
|
35 |
try:
|
36 |
# Encoding and clustering
|
37 |
query_embedding = cached_encode(message)
|
@@ -46,10 +48,10 @@ async def respond(message, history, system_message, max_tokens, temperature, top
|
|
46 |
"role": "user",
|
47 |
"content": message
|
48 |
}],
|
49 |
-
max_tokens=
|
50 |
stream=True,
|
51 |
-
temperature=
|
52 |
-
top_p=
|
53 |
)
|
54 |
|
55 |
full_response = ""
|
@@ -57,36 +59,25 @@ async def respond(message, history, system_message, max_tokens, temperature, top
|
|
57 |
content = chunk.choices[0].delta.content
|
58 |
if content:
|
59 |
full_response += content
|
60 |
-
yield full_response # Stream partial responses
|
61 |
|
62 |
-
#
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
|
65 |
except Exception as e:
|
66 |
yield f"Error: {str(e)}"
|
67 |
|
68 |
-
# ------ Gradio Interface
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
|
77 |
-
gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
|
78 |
-
]
|
79 |
-
)
|
80 |
-
with gr.Row():
|
81 |
-
example_button = gr.Button("Use Example Query")
|
82 |
-
example_query = gr.Textbox(
|
83 |
-
value="I have a headache and fever",
|
84 |
-
label="Example Query",
|
85 |
-
interactive=False
|
86 |
-
)
|
87 |
-
|
88 |
-
# Trigger example query manually when button is clicked
|
89 |
-
example_button.click(fn=lambda: "I have a headache and fever", outputs=[chatbot])
|
90 |
|
91 |
if __name__ == "__main__":
|
92 |
demo.launch()
|
|
|
17 |
# ------ Clustering Setup ------
|
18 |
kmeans = MiniBatchKMeans(n_clusters=10, random_state=42)
|
19 |
cluster_labels = kmeans.fit_predict(embedding_arr)
|
20 |
+
cluster_prognosis_map = df.groupby(cluster_labels)['prognosis'].agg(
|
21 |
+
lambda x: x.value_counts(normalize=True).head(3).to_dict()
|
22 |
+
)
|
23 |
|
24 |
# ------ PCA Initialization ------
|
25 |
pca = PCA(n_components=2).fit(embedding_arr)
|
|
|
33 |
client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
34 |
|
35 |
# ------ Streaming Response Function ------
|
36 |
+
async def respond(message, history, system_message):
|
37 |
try:
|
38 |
# Encoding and clustering
|
39 |
query_embedding = cached_encode(message)
|
|
|
48 |
"role": "user",
|
49 |
"content": message
|
50 |
}],
|
51 |
+
max_tokens=512,
|
52 |
stream=True,
|
53 |
+
temperature=0.7,
|
54 |
+
top_p=0.95
|
55 |
)
|
56 |
|
57 |
full_response = ""
|
|
|
59 |
content = chunk.choices[0].delta.content
|
60 |
if content:
|
61 |
full_response += content
|
|
|
62 |
|
63 |
+
# Format prognosis likelihoods
|
64 |
+
prognosis_info = cluster_prognosis_map[query_cluster]
|
65 |
+
formatted_prognoses = ", ".join([f"{k}: {v*100:.1f}%" for k, v in prognosis_info.items()])
|
66 |
+
|
67 |
+
# Append formatted prognosis to response
|
68 |
+
yield f"{full_response}\n\nCommon prognoses: {formatted_prognoses}"
|
69 |
|
70 |
except Exception as e:
|
71 |
yield f"Error: {str(e)}"
|
72 |
|
73 |
+
# ------ Gradio Interface ------
|
74 |
+
demo = gr.ChatInterface(
|
75 |
+
respond,
|
76 |
+
additional_inputs=[
|
77 |
+
gr.Textbox(value="Medical diagnosis assistant", label="System Role")
|
78 |
+
],
|
79 |
+
examples=[["I have a headache and fever"]] # Example query for guidance
|
80 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
demo.launch()
|