import asyncio
import io
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend: render plots without a display server
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import MiniBatchKMeans
from sentence_transformers import SentenceTransformer
from umap import UMAP
from huggingface_hub import AsyncInferenceClient

# ---- Precomputed Elements ----
file_path = 'symbipredict_2022_filtered.csv'
df = pd.read_csv(file_path)
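# The CSV is expected to provide 'symptoms' (free text) and 'prognosis' (label) columns.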
model = SentenceTransformer("all-MiniLM-L6-v2")
embedding_arr = model.encode(df['symptoms']).astype(np.float32)
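# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings.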

# Clustering
optimal_n_clusters = 41
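# NOTE: 41 is assumed to have been tuned offline for this dataset (e.g., via an
# elbow or silhouette analysis); treat it as a dataset-specific hyperparameter.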
kmeans = MiniBatchKMeans(n_clusters=optimal_n_clusters, 
                        batch_size=1024, 
                        random_state=42)
kmeans_labels = kmeans.fit_predict(embedding_arr)
df['cluster'] = kmeans_labels  # 'cluster' is not in the CSV; the groupby below needs it

# Dimensionality Reduction (2-D projection for plotting)
umap_model = UMAP(n_components=2, random_state=42)
embedding_umap = umap_model.fit_transform(embedding_arr)
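# Fixing random_state makes the projection reproducible, at the cost of
# forcing UMAP to run single-threaded.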

# Precomputed Mappings
cluster_prognosis_map = (df.groupby('cluster')['prognosis']
                         .unique().to_dict())
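# Maps each cluster id to the unique prognosis labels observed in that cluster.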

# ---- Cached Functions ----
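# lru_cache memoizes up to 100 recent queries, so repeated symptom strings
# skip a SentenceTransformer forward pass.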
@lru_cache(maxsize=100)
def cached_encode(text):
    return model.encode([text], convert_to_numpy=True)[0]

# ---- Optimized Plotting ----
fig = plt.figure(figsize=(14, 10))
ax = fig.add_subplot(111)
cmap = plt.get_cmap('tab20', optimal_n_clusters)
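# CAUTION: create_plot mutates this shared module-level figure, so it is not
# safe under concurrent requests; acceptable for a single-user demo.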

def create_plot(message, query_embedding):
    ax.clear()

    # Plot all points, colored by cluster assignment
    ax.scatter(embedding_umap[:, 0], embedding_umap[:, 1],
               c=kmeans_labels, cmap=cmap,
               edgecolor='k', linewidth=0.5, alpha=0.7)

    # Project the query into the same 2-D space and mark it
    query_umap = umap_model.transform(query_embedding.reshape(1, -1))
    ax.scatter(query_umap[:, 0], query_umap[:, 1],
               c='red', marker='X', s=200,
               label=f'Query: {message}')

    # Finalize plot and rasterize to a PIL image (displayable by gr.Image)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_title("Medical Condition Clustering")
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)

# ---- Response Function ----
# AsyncInferenceClient exposes an awaitable chat_completion, matching the async respond() below.
client = AsyncInferenceClient("HuggingFaceH4/zephyr-7b-beta")

async def respond(message, history, system_message, max_tokens, temperature, top_p):
    # history is provided by ChatInterface but not forwarded to the model here.
    try:
        # Encoding (served from cache for repeated queries)
        query_embedding = cached_encode(message)

        # Cluster prediction
        query_cluster = kmeans.predict(query_embedding.reshape(1, -1))[0]

        # Render the plot in a worker thread so the event loop stays responsive
        plot_img = await asyncio.to_thread(
            create_plot, message, query_embedding
        )

        # Async LLM response
        llm_response = await client.chat_completion(
            [{"role": "system", "content": system_message},
             {"role": "user", "content": message}],
            max_tokens=max_tokens,
            stream=False,
            temperature=temperature,
            top_p=top_p
        )
        answer = llm_response.choices[0].message.content

        # Combine the LLM answer with the cluster's known prognoses
        full_response = (
            f"{answer}\n\nCluster {query_cluster} contains: "
            f"{', '.join(cluster_prognosis_map[query_cluster])}"
        )

        return full_response, plot_img

    except Exception as e:
        return f"Error: {str(e)}", None

# ---- Gradio Interface ----
# additional_outputs (available in recent Gradio releases) lets respond()
# return the cluster plot alongside the chat reply.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Medical diagnosis assistant", label="System Role"),
        gr.Slider(512, 2048, value=512, step=128, label="Max Tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.05, label="Top-p")
    ],
    additional_outputs=[gr.Image(label="Cluster Plot")]
)

if __name__ == "__main__":
    demo.launch()