import os
import streamlit as st
import numpy as np
import faiss
from transformers import pipeline, AutoTokenizer, AutoModel
from groq import Groq
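# NeuroGuard: Streamlit chat app. Flow per user message:
#   1. LLaMA (via Groq) silently fixes grammar/spelling in the query.
#   2. Keyword routing decides between epilepsy-focused and general-health handling.
#   3. PubMedBERT embeds the query (vector stored in a FAISS index).
#   4. LLaMA generates the final answer, conditioned on the chat history.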
# Streamlit Page Config (must be the first Streamlit call, before any st.error/st.warning below)
st.set_page_config(page_title="NeuroGuard: Epilepsy & Health Chatbot", layout="wide")

# Load API Key from Environment
groq_api_key = os.environ.get("GROQ_API_KEY")
if groq_api_key is None:
    st.error("GROQ_API_KEY environment variable not set.")
    st.stop()
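# The key must be set in the environment before launch, e.g. (shell):
#   export GROQ_API_KEY="your-key-here" && streamlit run app.py
# ("app.py" is a placeholder for this script's filename.)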
# Initialize Groq Client
try:
    client = Groq(api_key=groq_api_key)
except Exception as e:
    st.error(f"Error initializing Groq client: {e}")
    st.stop()
# Load PubMedBERT Model (try the NeuML embedding checkpoint first, then fall back to Microsoft's)
try:
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("NeuML/pubmedbert-base-embeddings")
    pubmedbert_model = AutoModel.from_pretrained("NeuML/pubmedbert-base-embeddings")
    pubmedbert_pipeline = pipeline('feature-extraction', model=pubmedbert_model, tokenizer=pubmedbert_tokenizer, device=-1)
except Exception:
    st.warning("Error loading NeuML/pubmedbert-base-embeddings. Falling back to the Microsoft PubMedBERT checkpoint.")
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    # Use AutoModel here too: the feature-extraction pipeline needs hidden states,
    # and a sequence-classification head would return logits instead of 768-dim embeddings
    pubmedbert_model = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    pubmedbert_pipeline = pipeline('feature-extraction', model=pubmedbert_model, tokenizer=pubmedbert_tokenizer, device=-1)
# Initialize FAISS Index
embedding_dim = 768
index = faiss.IndexFlatL2(embedding_dim)
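# IndexFlatL2 does exact (brute-force) L2 search over 768-dim vectors. Note this script
# only ever add()s query embeddings and never calls index.search(); a retrieval step,
# if added later, would look roughly like (hypothetical, not wired in):
#   distances, ids = index.search(np.array([some_query_vec], dtype=np.float32), k=5)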
# Function to Check if Query is Related to Epilepsy
def preprocess_query(query):
    tokens = query.lower().split()
    epilepsy_keywords = ["seizure", "epilepsy", "convulsion", "neurology", "brain activity"]
    # Substring match against the whole query so multi-word keywords
    # ("brain activity") and plurals ("seizures") are caught
    is_epilepsy_related = any(k in query.lower() for k in epilepsy_keywords)
    return tokens, is_epilepsy_related
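# e.g. preprocess_query("Can stress trigger seizures?")
#   -> (['can', 'stress', 'trigger', 'seizures?'], True)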
# Function to Generate Response with Chat History
def generate_response(user_query, chat_history):
    # Grammatical Correction using LLaMA (Hidden from User)
    try:
        correction_prompt = f"""
Correct the following user query for grammar and spelling errors, but keep the original intent intact.
Do not add or remove any information, just fix the grammar.

User Query: {user_query}

Corrected Query:
"""
        grammar_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": correction_prompt}],
            model="llama-3.3-70b-versatile",
            stream=False,
        )
        corrected_query = grammar_completion.choices[0].message.content.strip()
        # If correction fails or returns empty, use the original query
        if not corrected_query:
            corrected_query = user_query
    except Exception as e:
        corrected_query = user_query  # Fall back to the original query if correction fails
        print(f"⚠️ Grammar correction error: {e}")  # Optional: log the error for debugging

    tokens, is_epilepsy_related = preprocess_query(corrected_query)  # Route on the corrected query
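    # From here the function branches: greetings get a canned reply; epilepsy keywords
    # trigger PubMedBERT analysis plus an epilepsy-focused LLaMA prompt; anything else
    # falls through to a general-health LLaMA prompt.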
    # Greeting Responses
    greetings = ["hello", "hi", "hey"]
    if any(word in tokens for word in greetings):
        return "👋 Hello! How can I assist you today?"
    # If Epilepsy Related - Use Epilepsy Focused Response
    if is_epilepsy_related:
        # Try Getting Medical Insights from PubMedBERT
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query)  # Token-level embeddings for the corrected query
            # Mean-pool the token embeddings into a single 768-dim vector
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0)
            index.add(np.array([embedding_mean], dtype=np.float32))  # FAISS requires float32
            pubmedbert_insights = "**PubMedBERT Analysis:** ✅ Query is relevant to epilepsy research."
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"
        # Use LLaMA for Final Response Generation with Chat History Context (Epilepsy Focus)
        try:
            prompt_history = ""
            if chat_history:
                prompt_history += "**Chat History:**\n"
                for message in chat_history:
                    prompt_history += f"{message['role'].capitalize()}: {message['content']}\n"
                prompt_history += "\n"
            # Build the prompt from the corrected query (keep Python comments out of the
            # f-string, or they get sent to the model as literal prompt text)
            epilepsy_prompt = f"""
{prompt_history}
**User Query:** {corrected_query}

**Instructions:** Provide a concise, structured, and human-friendly response specifically about epilepsy or seizures, considering the conversation history if available.
"""
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": epilepsy_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"

        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"
    # If Not Epilepsy Related - Try to Answer as General Health Query
    else:
        # Try Getting Medical Insights from PubMedBERT (even for general health)
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query)
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0)
            index.add(np.array([embedding_mean], dtype=np.float32))  # FAISS requires float32
            pubmedbert_insights = "**PubMedBERT Analysis:** PubMedBERT analysis performed for health-related context."
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"
        # Use LLaMA for General Health Response Generation with Chat History Context
        try:
            prompt_history = ""
            if chat_history:
                prompt_history += "**Chat History:**\n"
                for message in chat_history:
                    prompt_history += f"{message['role'].capitalize()}: {message['content']}\n"
                prompt_history += "\n"
            general_health_prompt = f"""
{prompt_history}
**User Query:** {corrected_query}

**Instructions:** Provide a concise, structured, and human-friendly response to the general health query, considering the conversation history if available. If the query is clearly not health-related, respond generally.
"""
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": general_health_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"

        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"
# Streamlit UI Setup (page config was already set at the top of the script)
st.title("🧠 NeuroGuard: Epilepsy & Health Chatbot")
st.write("💬 Ask me anything about epilepsy, seizures, and general health. I remember our conversation!")
# Initialize Chat History in Session State
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
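# st.session_state persists across Streamlit reruns, so the chat history survives each
# interaction; the module-level FAISS index does not (it is rebuilt on every rerun).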
# Display Chat History
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# User Input
if prompt := st.chat_input("Type your question here..."):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate Bot Response
    with st.chat_message("bot"):
        with st.spinner("🤔 Thinking..."):
            try:
                response = generate_response(prompt, st.session_state.chat_history)  # Pass chat history here
                st.markdown(response)
                st.session_state.chat_history.append({"role": "bot", "content": response})
            except Exception as e:
                st.error(f"⚠️ Error processing query: {e}")