import os
import streamlit as st
import numpy as np
import faiss
from transformers import pipeline, AutoTokenizer, AutoModel
from groq import Groq
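# NeuroGuard wires together three pieces: a Streamlit chat UI, PubMedBERT
# embeddings stored in a FAISS index, and Groq-hosted LLaMA for generation.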
# Load API Key from Environment
groq_api_key = os.environ.get("GROQ_API_KEY")
if groq_api_key is None:
    st.error("GROQ_API_KEY environment variable not set.")
    st.stop()
# Initialize Groq Client
try:
    client = Groq(api_key=groq_api_key)
except Exception as e:
    st.error(f"Error initializing Groq client: {e}")
    st.stop()
# Load PubMedBERT for embeddings (try the NeuML checkpoint first, then fall
# back to the base Microsoft checkpoint). Both are Hugging Face models; Groq
# only serves the LLaMA chat model used below.
try:
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("NeuML/pubmedbert-base-embeddings")
    pubmedbert_model = AutoModel.from_pretrained("NeuML/pubmedbert-base-embeddings")
except Exception:
    st.warning("Could not load NeuML/pubmedbert-base-embeddings. Falling back to the base PubMedBERT model.")
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    pubmedbert_model = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
# The feature-extraction pipeline returns per-token hidden states; device=-1 runs on CPU.
pubmedbert_pipeline = pipeline("feature-extraction", model=pubmedbert_model, tokenizer=pubmedbert_tokenizer, device=-1)
# Initialize FAISS Index
embedding_dim = 768  # hidden size of PubMedBERT
index = faiss.IndexFlatL2(embedding_dim)
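# IndexFlatL2 does exact (brute-force) L2 search, which is fine at this scale;
# note that FAISS expects float32 vectors when adding or searching.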
# Function to Check if Query is Related to Epilepsy
def preprocess_query(query):
    tokens = query.lower().split()
    epilepsy_keywords = ["seizure", "epilepsy", "convulsion", "neurology", "brain activity"]
    # Substring match against the whole query so multi-word keywords
    # such as "brain activity" can also be detected.
    is_epilepsy_related = any(k in query.lower() for k in epilepsy_keywords)
    return tokens, is_epilepsy_related
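# Example (hypothetical input): preprocess_query("Can stress trigger a seizure?")
# returns (["can", "stress", "trigger", "a", "seizure?"], True).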
# Function to Generate Response with Chat History
def generate_response(user_query, chat_history):
    # Grammatical Correction using LLaMA (Hidden from User)
    try:
        correction_prompt = f"""
        Correct the following user query for grammar and spelling errors, but keep the original intent intact.
        Do not add or remove any information, just fix the grammar.
        User Query: {user_query}
        Corrected Query:
        """
        grammar_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": correction_prompt}],
            model="llama-3.3-70b-versatile",
            stream=False,
        )
        corrected_query = grammar_completion.choices[0].message.content.strip()
        # If the correction returns empty, use the original query
        if not corrected_query:
            corrected_query = user_query
    except Exception as e:
        corrected_query = user_query  # Fall back to the original query if correction fails
        print(f"⚠️ Grammar correction error: {e}")  # Optional: log the error for debugging

    tokens, is_epilepsy_related = preprocess_query(corrected_query)  # Use the corrected query for routing
    # Greeting Responses
    greetings = ["hello", "hi", "hey"]
    if any(word in tokens for word in greetings):
        return "👋 Hello! How can I assist you today?"
    # If Epilepsy Related - Use Epilepsy Focused Response
    if is_epilepsy_related:
        # Try Getting Medical Insights from PubMedBERT
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query)
            # Mean-pool the per-token embeddings into a single 768-d vector.
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0)
            index.add(np.array([embedding_mean], dtype="float32"))  # FAISS requires float32
            pubmedbert_insights = "**PubMedBERT Analysis:** ✅ Query is relevant to epilepsy research."
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"
        # Use LLaMA for Final Response Generation with Chat History Context (Epilepsy Focus)
        try:
            prompt_history = ""
            if chat_history:
                prompt_history += "**Chat History:**\n"
                for message in chat_history:
                    prompt_history += f"{message['role'].capitalize()}: {message['content']}\n"
                prompt_history += "\n"
            # The corrected query goes into the prompt; any Python comments must
            # stay outside the f-string or they would leak into the prompt text.
            epilepsy_prompt = f"""
            {prompt_history}
            **User Query:** {corrected_query}
            **Instructions:** Provide a concise, structured, and human-friendly response specifically about epilepsy or seizures, considering the conversation history if available.
            """
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": epilepsy_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"
        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"
    # If Not Epilepsy Related - Try to Answer as General Health Query
    else:
        # Try Getting Medical Insights from PubMedBERT (even for general health)
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query)
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0)
            index.add(np.array([embedding_mean], dtype="float32"))  # FAISS requires float32
            pubmedbert_insights = "**PubMedBERT Analysis:** PubMed analysis performed for health-related context."
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"
        # Use LLaMA for General Health Response Generation with Chat History Context
        try:
            prompt_history = ""
            if chat_history:
                prompt_history += "**Chat History:**\n"
                for message in chat_history:
                    prompt_history += f"{message['role'].capitalize()}: {message['content']}\n"
                prompt_history += "\n"
            general_health_prompt = f"""
            {prompt_history}
            **User Query:** {corrected_query}
            **Instructions:** Provide a concise, structured, and human-friendly response to the general health query, considering the conversation history if available. If the query is clearly not health-related, respond generally.
            """
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": general_health_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"
        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"
# Streamlit UI Setup
st.set_page_config(page_title="NeuroGuard: Epilepsy & Health Chatbot", layout="wide")
st.title("🧠 NeuroGuard: Epilepsy & Health Chatbot")
st.write("💬 Ask me anything about epilepsy, seizures, and general health. I remember our conversation!")
# Initialize Chat History in Session State
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# Display Chat History
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# User Input
if prompt := st.chat_input("Type your question here..."):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    # Generate Bot Response ("assistant" is the role Streamlit styles for bots)
    with st.chat_message("assistant"):
        with st.spinner("🤔 Thinking..."):
            try:
                response = generate_response(prompt, st.session_state.chat_history)  # Pass chat history here
                st.markdown(response)
                st.session_state.chat_history.append({"role": "assistant", "content": response})
            except Exception as e:
                st.error(f"⚠️ Error processing query: {e}")