import os
import streamlit as st
import numpy as np
import faiss
from transformers import pipeline, AutoTokenizer, AutoModel
from groq import Groq

# Load API Key from Environment
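# Set it before launching the app, e.g. `export GROQ_API_KEY=...` (key value elided).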
groq_api_key = os.environ.get("GROQ_API_KEY")
if groq_api_key is None:
    st.error("GROQ_API_KEY environment variable not set.")
    st.stop()

# Initialize Groq Client
try:
    client = Groq(api_key=groq_api_key)
except Exception as e:
    st.error(f"Error initializing Groq client: {e}")
    st.stop()

# Load PubMedBERT (try the NeuML embeddings checkpoint first, then fall back
# to the original Microsoft checkpoint; both are Hugging Face models)
try:
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("NeuML/pubmedbert-base-embeddings")
    pubmedbert_model = AutoModel.from_pretrained("NeuML/pubmedbert-base-embeddings")
    pubmedbert_pipeline = pipeline('feature-extraction', model=pubmedbert_model, tokenizer=pubmedbert_tokenizer, device=-1)
except Exception:
    st.warning("Could not load NeuML/pubmedbert-base-embeddings. Falling back to the Microsoft PubMedBERT checkpoint.")
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    # AutoModel (not AutoModelForSequenceClassification): feature extraction
    # needs the bare encoder's hidden states, not a classification head.
    pubmedbert_model = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    pubmedbert_pipeline = pipeline('feature-extraction', model=pubmedbert_model, tokenizer=pubmedbert_tokenizer, device=-1)

# Initialize FAISS Index
embedding_dim = 768
index = faiss.IndexFlatL2(embedding_dim)
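
# NOTE: the index above is only ever written to; nothing in this app queries
# it yet. A retrieval step, if added later, could look roughly like this
# sketch (hypothetical helper, not called anywhere):
def find_similar(query_embedding, k=3):
    """Return distances and positions of the k nearest stored embeddings."""
    query = np.array([query_embedding], dtype="float32")  # FAISS expects float32
    distances, positions = index.search(query, k)
    return distances[0], positions[0]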

# Function to Check if Query is Related to Epilepsy
def preprocess_query(query):
    normalized = query.lower()
    tokens = normalized.split()
    epilepsy_keywords = ["seizure", "epilepsy", "convulsion", "neurology", "brain activity"]

    # Substring match against the whole query: exact token equality would miss
    # punctuated tokens ("seizure?") and the two-word phrase "brain activity".
    is_epilepsy_related = any(k in normalized for k in epilepsy_keywords)

    return tokens, is_epilepsy_related
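
# Example (illustrative):
#   preprocess_query("Can epilepsy be cured?")
#   -> (['can', 'epilepsy', 'be', 'cured?'], True)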

# Function to Generate Response with Chat History
def generate_response(user_query, chat_history):
    # Grammatical Correction using LLaMA (Hidden from User)
    try:
        correction_prompt = f"""
        Correct the following user query for grammar and spelling errors, but keep the original intent intact.
        Do not add or remove any information, just fix the grammar.
        User Query: {user_query}
        Corrected Query:
        """
        grammar_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": correction_prompt}],
            model="llama-3.3-70b-versatile",
            stream=False,
        )
        corrected_query = grammar_completion.choices[0].message.content.strip()
        # If correction fails or returns empty, use original query
        if not corrected_query:
            corrected_query = user_query
    except Exception as e:
        corrected_query = user_query # Fallback to original query if correction fails
        print(f"⚠️ Grammar correction error: {e}") # Optional: Log the error for debugging
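    # Illustrative behavior (actual output depends on the model):
    #   "wat is seizur" -> "What is a seizure?"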

    tokens, is_epilepsy_related = preprocess_query(corrected_query) # Use corrected query for processing

    # Greeting Responses
    greetings = ["hello", "hi", "hey"]
    if any(word in tokens for word in greetings):
        return "👋 Hello! How can I assist you today?"

    # If Epilepsy Related - Use Epilepsy Focused Response
    if is_epilepsy_related:
        # Try Getting Medical Insights from PubMedBERT
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query) # Use corrected query for PubMedBERT
            # Output is [1 x seq_len x 768]; mean-pool over tokens and cast to
            # float32, since FAISS rejects float64 input.
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0).astype("float32")
            index.add(np.array([embedding_mean]))
            pubmedbert_insights = "**PubMedBERT Analysis:** ✅ Query is relevant to epilepsy research."
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"

        # Use LLaMA for Final Response Generation with Chat History Context (Epilepsy Focus)
        try:
            prompt_history = ""
            if chat_history:
                prompt_history += "**Chat History:**\n"
                for message in chat_history:
                    prompt_history += f"{message['role'].capitalize()}: {message['content']}\n"
                prompt_history += "\n"

            # Use the corrected query for final response generation. (The inline
            # comment must stay outside the f-string, or it gets sent to the model.)
            epilepsy_prompt = f"""
            {prompt_history}
            **User Query:** {corrected_query}
            **Instructions:** Provide a concise, structured, and human-friendly response specifically about epilepsy or seizures, considering the conversation history if available.
            """

            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": epilepsy_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"

        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"


    # If Not Epilepsy Related - Try to Answer as General Health Query
    else:
        # Try Getting Medical Insights from PubMedBERT (even for general health)
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query)
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0).astype("float32") # float32 for FAISS
            index.add(np.array([embedding_mean]))
            pubmedbert_insights = "**PubMedBERT Analysis:** PubMedBERT analysis performed for health-related context." # General analysis message
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"

        # Use LLaMA for General Health Response Generation with Chat History Context
        try:
            prompt_history = ""
            if chat_history:
                prompt_history += "**Chat History:**\n"
                for message in chat_history:
                    prompt_history += f"{message['role'].capitalize()}: {message['content']}\n"
                prompt_history += "\n"

            general_health_prompt = f"""
            {prompt_history}
            **User Query:** {corrected_query}
            **Instructions:** Provide a concise, structured, and human-friendly response to the general health query, considering the conversation history if available. If the query is clearly not health-related, respond generally.
            """

            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": general_health_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"

        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"


# Streamlit UI Setup
st.set_page_config(page_title="NeuroGuard: Epilepsy & Health Chatbot", layout="wide")
st.title("🧠 NeuroGuard: Epilepsy & Health Chatbot")
st.write("💬 Ask me anything about epilepsy, seizures, and general health. I remember our conversation!")

# Initialize Chat History in Session State
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
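# Each entry has the form {"role": "user" | "bot", "content": str}; "bot" is a
# custom author name for st.chat_message (not one of Streamlit's presets).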

# Display Chat History
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User Input
if prompt := st.chat_input("Type your question here..."):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate Bot Response
    with st.chat_message("bot"):
        with st.spinner("🤖 Thinking..."):
            try:
                response = generate_response(prompt, st.session_state.chat_history) # Pass chat history here
                st.markdown(response)
                st.session_state.chat_history.append({"role": "bot", "content": response})
            except Exception as e:
                st.error(f"⚠️ Error processing query: {e}")
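
# To launch (assuming this file is saved as app.py; the filename is illustrative):
#   streamlit run app.py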