"""NeuroGuard: a Streamlit voice assistant for epilepsy and general health questions.

Flow: record 10 seconds of microphone audio, transcribe it with Whisper,
grammar-correct the transcript with LLaMA (via Groq), embed the query with a
PubMedBERT feature-extraction pipeline (added to a FAISS index), then ask
LLaMA for an answer and render it in the page.
"""

import contextlib
import os
import re
import tempfile

import faiss
import numpy as np
import sounddevice as sd
import soundfile as sf
import streamlit as st
import whisper
from groq import Groq
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModel

# Page configuration is issued before any other Streamlit command (including
# the st.error / st.warning calls in the setup code below).
st.set_page_config(page_title="NeuroGuard: Voice Call Health Assistant", layout="wide")

# Load API Key from Environment
groq_api_key = os.environ.get("GROQ_API_KEY")
if groq_api_key is None:
    st.error("GROQ_API_KEY environment variable not set.")
    st.stop()

# Initialize Groq Client
try:
    client = Groq(api_key=groq_api_key)
except Exception as e:
    st.error(f"Error initializing Groq client: {e}")
    st.stop()

# Chat model used for both grammar correction and answer generation.
LLAMA_MODEL = "llama-3.3-70b-versatile"

# Both PubMedBERT checkpoints are loaded through Hugging Face transformers;
# the second one is only tried when the first fails to load.
PRIMARY_PUBMEDBERT = "NeuML/pubmedbert-base-embeddings"
FALLBACK_PUBMEDBERT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"


@st.cache_resource
def _load_pubmedbert(model_name):
    """Build a CPU feature-extraction pipeline for ``model_name``.

    Cached with ``st.cache_resource`` so the weights are not reloaded each
    time Streamlit re-executes the script.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # AutoModel (the bare encoder) is required here: a sequence-classification
    # head would make the pipeline emit logits instead of hidden states.
    model = AutoModel.from_pretrained(model_name)
    extractor = pipeline('feature-extraction', model=model, tokenizer=tokenizer, device=-1)
    return tokenizer, model, extractor


@st.cache_resource
def _load_whisper_model(size):
    """Load the Whisper speech-to-text model of the given size (cached)."""
    return whisper.load_model(size)


# Load PubMedBERT Model (try the NeuML embeddings model first, then the
# Microsoft BiomedNLP checkpoint)
try:
    pubmedbert_tokenizer, pubmedbert_model, pubmedbert_pipeline = _load_pubmedbert(PRIMARY_PUBMEDBERT)
except Exception:
    st.warning("Error loading NeuML PubMedBERT embeddings model. Using Microsoft BiomedNLP PubMedBERT model.")
    pubmedbert_tokenizer, pubmedbert_model, pubmedbert_pipeline = _load_pubmedbert(FALLBACK_PUBMEDBERT)

# Initialize FAISS Index
# NOTE(review): this index is only ever written to in this file (never
# searched) and is recreated whenever the script is re-executed.
embedding_dim = 768
index = faiss.IndexFlatL2(embedding_dim)

# Load Whisper model
try:
    whisper_model = _load_whisper_model("base")  # or "small", "medium", "large"
except Exception as e:
    st.error(f"Error loading Whisper model: {e}")
    st.stop()

EPILEPSY_KEYWORDS = ("seizure", "epilepsy", "convulsion", "neurology", "brain activity")

# Whole-word / whole-phrase match with an optional plural "s". Multi-word
# keywords tolerate any run of whitespace between their words.
_EPILEPSY_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(r"\s+".join(re.escape(word) for word in keyword.split()) for keyword in EPILEPSY_KEYWORDS)
    + r")s?\b"
)


# Function to Check if Query is Related to Epilepsy
def preprocess_query(query):
    """Tokenize ``query`` and flag whether it mentions an epilepsy keyword.

    Args:
        query: The (grammar-corrected) user query text.

    Returns:
        A ``(tokens, is_epilepsy_related)`` tuple. ``tokens`` is the
        lower-cased query split on whitespace. ``is_epilepsy_related`` is
        True when any entry of ``EPILEPSY_KEYWORDS`` occurs in the query as a
        whole word or phrase, regardless of adjacent punctuation
        ("epilepsy?") or a plural "s" ("seizures").
    """
    lowered = query.lower()
    tokens = lowered.split()
    # Matching is done on the text rather than on `tokens`: token membership
    # can never match the two-word keyword and misses punctuated words.
    is_epilepsy_related = _EPILEPSY_PATTERN.search(lowered) is not None
    return tokens, is_epilepsy_related


def _ask_llama(prompt):
    """Send ``prompt`` as a single user message to LLaMA and return the stripped reply."""
    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=LLAMA_MODEL,
        stream=False,
    )
    return completion.choices[0].message.content.strip()


def _index_query_embedding(query):
    """Mean-pool the PubMedBERT token embeddings of ``query`` and add them to the FAISS index."""
    token_embeddings = pubmedbert_pipeline(query)
    embedding_mean = np.mean(token_embeddings[0], axis=0)
    # np.mean yields float64; FAISS indexes operate on float32 matrices.
    index.add(np.asarray([embedding_mean], dtype=np.float32))


# Function to Generate Response
def generate_response(user_query):
    """Produce NeuroGuard's markdown reply for a transcribed user query.

    Steps: grammar-correct the query with LLaMA (falling back to the raw
    query on error or empty output), short-circuit on greetings, embed the
    query with PubMedBERT, then ask LLaMA with an epilepsy-specific or a
    general-health prompt depending on ``preprocess_query``.

    Args:
        user_query: The text the user said.

    Returns:
        A markdown string. Failures of the embedding step or of the answer
        generation are reported inside the returned text rather than raised.
    """
    # Grammatical Correction using LLaMA (Hidden from User)
    try:
        correction_prompt = f"""
        Correct the following user query for grammar and spelling errors, but keep the original intent intact.
        Do not add or remove any information, just fix the grammar.

        User Query: {user_query}

        Corrected Query:
        """
        # An empty correction falls back to the original query.
        corrected_query = _ask_llama(correction_prompt) or user_query
    except Exception as e:
        corrected_query = user_query
        print(f"⚠️ Grammar correction error: {e}")

    tokens, is_epilepsy_related = preprocess_query(corrected_query)

    # Greeting Responses
    greetings = ["hello", "hi", "hey"]
    if any(word in tokens for word in greetings):
        return "👋 Hello! How can I assist you today?"

    if is_epilepsy_related:
        # Epilepsy Related Response
        success_insights = "**PubMedBERT Analysis:** ✅ Query is relevant to epilepsy research."
        answer_prompt = f"""
        **User Query:** {corrected_query}
        **Instructions:** Provide a concise, structured, and human-friendly response specifically about epilepsy or seizures.
        """
    else:
        # General Health Response
        success_insights = "**PubMedBERT Analysis:** PubMed analysis performed for health-related context."
        answer_prompt = f"""
        **User Query:** {corrected_query}
        **Instructions:** Provide a concise, structured, and human-friendly response to the general health query. If the query is clearly not health-related, respond generally.
        """

    # The embedding step is best-effort: a failure is reported in the reply
    # but does not prevent the answer from being generated.
    try:
        _index_query_embedding(corrected_query)
        pubmedbert_insights = success_insights
    except Exception as e:
        pubmedbert_insights = f"⚠️ Error during PubMedBERT analysis: {e}"

    try:
        model_response = _ask_llama(answer_prompt)
    except Exception as e:
        model_response = f"⚠️ Error generating response with LLaMA: {e}"

    return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"


# Streamlit UI Setup
st.title("📞 NeuroGuard: Voice Call Health Assistant")
st.write("🎙️ Click 'Start Recording', speak your question, and NeuroGuard will respond.")

bot_response_area = st.empty()  # For bot response display

if st.button("Start Recording"):
    st.write("Recording... Speak now.")
    fs = 44100
    seconds = 10
    try:
        myrecording = sd.rec(int(seconds * fs), samplerate=fs, channels=1)
        sd.wait()

        # Reserve a path, then close the handle before writing to it by name
        # (writing to a still-open NamedTemporaryFile fails on Windows).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_audio_path = temp_audio.name
        try:
            sf.write(temp_audio_path, myrecording, fs)
            with open(temp_audio_path, 'rb') as audio_file:
                audio_bytes = audio_file.read()
            st.audio(audio_bytes, format="audio/wav")

            with st.spinner("Transcribing audio..."):
                transcription = whisper_model.transcribe(temp_audio_path)
        finally:
            # delete=False means the file must be removed explicitly.
            with contextlib.suppress(OSError):
                os.remove(temp_audio_path)

        user_query = transcription["text"].strip()

        if not user_query:
            # Nothing was transcribed; do not send an empty query to the LLM.
            st.warning("No speech was detected. Please try recording again.")
        else:
            st.write(f"**You said:** {user_query}")

            with bot_response_area.container():
                with st.chat_message("bot"):
                    with st.spinner("🤖 Thinking..."):
                        try:
                            response = generate_response(user_query)
                            st.markdown(response)
                        except Exception as e:
                            st.error(f"⚠️ Error processing query: {e}")

    except Exception as e:
        st.error(f"⚠️ Recording error: {e}")