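"""NeuroGuard: a Streamlit voice assistant for epilepsy and general health questions.

Pipeline: record audio with sounddevice -> transcribe with Whisper -> embed the
query with PubMedBERT (stored in a FAISS index) -> answer with a LLaMA model
served by Groq.
"""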
import os
import tempfile

import faiss
import numpy as np
import sounddevice as sd
import soundfile as sf
import streamlit as st
import whisper
from groq import Groq
from transformers import pipeline, AutoTokenizer, AutoModel
# Load API key from environment
groq_api_key = os.environ.get("GROQ_API_KEY")
if groq_api_key is None:
    st.error("GROQ_API_KEY environment variable not set.")
    st.stop()
# Initialize Groq client
try:
    client = Groq(api_key=groq_api_key)
except Exception as e:
    st.error(f"Error initializing Groq client: {e}")
    st.stop()
# Load PubMedBERT (try the NeuML embeddings checkpoint first, then fall back to
# the Microsoft base model). Both are loaded locally via Hugging Face transformers.
try:
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("NeuML/pubmedbert-base-embeddings")
    pubmedbert_model = AutoModel.from_pretrained("NeuML/pubmedbert-base-embeddings")
except Exception:
    st.warning("Could not load NeuML/pubmedbert-base-embeddings. Falling back to the Microsoft PubMedBERT checkpoint.")
    pubmedbert_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
    # AutoModel, not AutoModelForSequenceClassification: feature extraction needs
    # the 768-dim hidden states, not classification logits.
    pubmedbert_model = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
pubmedbert_pipeline = pipeline('feature-extraction', model=pubmedbert_model, tokenizer=pubmedbert_tokenizer, device=-1)
# Initialize FAISS index
embedding_dim = 768
index = faiss.IndexFlatL2(embedding_dim)
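# IndexFlatL2 performs exact L2 (Euclidean) nearest-neighbour search; 768 is
# PubMedBERT's hidden size. FAISS expects float32 input, which is why the query
# embeddings below are cast before index.add().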
# Load Whisper model
try:
    whisper_model = whisper.load_model("base")  # or "small", "medium", "large"
except Exception as e:
    st.error(f"Error loading Whisper model: {e}")
    st.stop()
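# Whisper decodes audio through ffmpeg and resamples to 16 kHz internally, so
# the 44.1 kHz WAV recorded below can be passed in as-is (ffmpeg must be
# installed on the host).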
# Check whether a query is related to epilepsy
def preprocess_query(query):
    query_lower = query.lower()
    tokens = query_lower.split()
    epilepsy_keywords = ["seizure", "epilepsy", "convulsion", "neurology", "brain activity"]
    # Substring match against the full query so multi-word keywords such as
    # "brain activity" are detected (a token-list lookup would never match them)
    is_epilepsy_related = any(k in query_lower for k in epilepsy_keywords)
    return tokens, is_epilepsy_related
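# Example: preprocess_query("My brother had a seizure yesterday")
# returns (['my', 'brother', 'had', 'a', 'seizure', 'yesterday'], True).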
# Generate a response to the user's query
def generate_response(user_query):
    # Grammatical correction using LLaMA (hidden from the user)
    try:
        correction_prompt = f"""
        Correct the following user query for grammar and spelling errors, but keep the original intent intact.
        Do not add or remove any information, just fix the grammar.
        User Query: {user_query}
        Corrected Query:
        """
        grammar_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": correction_prompt}],
            model="llama-3.3-70b-versatile",
            stream=False,
        )
        corrected_query = grammar_completion.choices[0].message.content.strip()
        # Fall back to the raw query if the correction comes back empty
        if not corrected_query:
            corrected_query = user_query
    except Exception as e:
        corrected_query = user_query
        print(f"⚠️ Grammar correction error: {e}")
    tokens, is_epilepsy_related = preprocess_query(corrected_query)

    # Greeting responses
    greetings = ["hello", "hi", "hey"]
    if any(word in tokens for word in greetings):
        return "👋 Hello! How can I assist you today?"
    # Epilepsy-related response
    if is_epilepsy_related:
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query)
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0)
            # FAISS expects a 2-D float32 array of shape (n, embedding_dim)
            index.add(np.array([embedding_mean], dtype=np.float32))
            pubmedbert_insights = "**PubMedBERT Analysis:** ✅ Query is relevant to epilepsy research."
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"
        try:
            epilepsy_prompt = f"""
            **User Query:** {corrected_query}
            **Instructions:** Provide a concise, structured, and human-friendly response specifically about epilepsy or seizures.
            """
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": epilepsy_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"
        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"
    # General health response
    else:
        try:
            pubmedbert_embeddings = pubmedbert_pipeline(corrected_query)
            embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0)
            index.add(np.array([embedding_mean], dtype=np.float32))
            pubmedbert_insights = "**PubMedBERT Analysis:** PubMed analysis performed for health-related context."
        except Exception as e:
            pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"
        try:
            general_health_prompt = f"""
            **User Query:** {corrected_query}
            **Instructions:** Provide a concise, structured, and human-friendly response to the general health query. If the query is clearly not health-related, respond generally.
            """
            chat_completion = client.chat.completions.create(
                messages=[{"role": "user", "content": general_health_prompt}],
                model="llama-3.3-70b-versatile",
                stream=False,
            )
            model_response = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            model_response = f"⚠️ Error generating response with LLaMA: {e}"
        return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"
# Streamlit UI setup
st.set_page_config(page_title="NeuroGuard: Voice Call Health Assistant", layout="wide")
st.title("📞 NeuroGuard: Voice Call Health Assistant")
st.write("🎙️ Click 'Start Recording', speak your question, and NeuroGuard will respond.")
bot_response_area = st.empty()  # Placeholder for the bot's response
if st.button("Start Recording"):
    st.write("Recording... Speak now.")
    fs = 44100  # sample rate in Hz
    seconds = 10  # fixed-length recording window
    try:
        # sd.rec is non-blocking; sd.wait() blocks until the recording finishes
        myrecording = sd.rec(int(seconds * fs), samplerate=fs, channels=1)
        sd.wait()
        # Create the temp file and close it first, then write to its path, so
        # the WAV can be reopened on all platforms (including Windows)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_audio_path = temp_audio.name
        sf.write(temp_audio_path, myrecording, fs)
        with open(temp_audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        st.audio(audio_bytes, format="audio/wav")
        with st.spinner("Transcribing audio..."):
            transcription = whisper_model.transcribe(temp_audio_path)
            user_query = transcription["text"]
        st.write(f"**You said:** {user_query}")
        with bot_response_area.container():
            with st.chat_message("bot"):
                with st.spinner("🤖 Thinking..."):
                    try:
                        response = generate_response(user_query)
                        st.markdown(response)
                    except Exception as e:
                        st.error(f"⚠️ Error processing query: {e}")
    except Exception as e:
        st.error(f"⚠️ Recording error: {e}")