# medi-call / app.py
# Author: Haseeb-001 — commit 25920f8 (verified): "Update app.py"
import os
import re
import string
import tempfile

import streamlit as st
import numpy as np
import faiss
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from groq import Groq
import sounddevice as sd
import soundfile as sf
import whisper
# Load API Key from Environment
groq_api_key = os.environ.get("GROQ_API_KEY")
# Treat an empty or whitespace-only value like an unset variable: it would pass
# an `is None` check and presumably only fail later, once a request is made.
if not groq_api_key or not groq_api_key.strip():
    st.error("GROQ_API_KEY environment variable not set.")
    st.stop()
# Initialize Groq Client: one shared client used for every chat completion in this app.
try:
    client = Groq(api_key=groq_api_key)
except Exception as err:
    # Nothing downstream works without a client, so surface the error and halt this run.
    st.error(f"Error initializing Groq client: {err}")
    st.stop()
# Load PubMedBERT Model: NeuML embedding checkpoint first, Microsoft PubMedBERT
# as a fallback. Both checkpoints come from the Hugging Face Hub (not Groq).
_PUBMEDBERT_PRIMARY = "NeuML/pubmedbert-base-embeddings"
_PUBMEDBERT_FALLBACK = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"


@st.cache_resource(show_spinner=False)
def _load_pubmedbert(model_name):
    """Load a tokenizer, bare encoder and CPU feature-extraction pipeline.

    Cached with ``st.cache_resource`` so the weights are loaded once per server
    process rather than on every Streamlit rerun. ``show_spinner=False`` keeps
    this call from emitting a Streamlit element before ``st.set_page_config``.

    Args:
        model_name: Hugging Face Hub checkpoint id.

    Returns:
        ``(tokenizer, model, feature_extraction_pipeline)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Use the bare encoder (AutoModel) for both checkpoints. The feature-extraction
    # pipeline returns the model's first output, which must be per-token hidden
    # states. A sequence-classification head would return logits instead, and
    # those do not fit the 768-dim FAISS index.
    model = AutoModel.from_pretrained(model_name)
    extractor = pipeline("feature-extraction", model=model, tokenizer=tokenizer, device=-1)
    return tokenizer, model, extractor


try:
    pubmedbert_tokenizer, pubmedbert_model, pubmedbert_pipeline = _load_pubmedbert(_PUBMEDBERT_PRIMARY)
except Exception as exc:
    # Logged to the console instead of st.warning. This runs before
    # st.set_page_config, and Streamlit versions that require set_page_config
    # to be the first Streamlit command would otherwise raise.
    print(f"⚠️ Could not load {_PUBMEDBERT_PRIMARY} ({exc}); falling back to {_PUBMEDBERT_FALLBACK}.")
    pubmedbert_tokenizer, pubmedbert_model, pubmedbert_pipeline = _load_pubmedbert(_PUBMEDBERT_FALLBACK)
# Initialize FAISS Index: exact (brute-force) L2 index.
# The dimension is taken from the loaded encoder's config, so it always matches
# the mean-pooled PubMedBERT vectors; 768 is kept as the default.
# NOTE(review): in this file vectors are only added (never searched), and the
# index is recreated each time Streamlit re-executes the script.
embedding_dim = getattr(pubmedbert_model.config, "hidden_size", 768)
index = faiss.IndexFlatL2(embedding_dim)
# Load Whisper model
@st.cache_resource(show_spinner=False)
def _load_whisper_model(model_size):
    """Load a Whisper speech-to-text model once per server process.

    Without caching, Streamlit would reload the model on every rerun (every
    button click re-executes the script). ``show_spinner=False`` keeps this call
    from emitting a Streamlit element before ``st.set_page_config``.

    Args:
        model_size: Name passed to ``whisper.load_model``, e.g. "base", "small",
            "medium" or "large".
    """
    return whisper.load_model(model_size)


try:
    whisper_model = _load_whisper_model("base")  # or "small", "medium", "large"
except Exception as e:
    st.error(f"Error loading Whisper model: {e}")
    st.stop()
# Function to Check if Query is Related to Epilepsy
_EPILEPSY_KEYWORDS = ("seizure", "epilepsy", "convulsion", "neurology", "brain activity")
# One pattern for all keywords: whole-word match, optional plural "s", and
# multi-word phrases such as "brain activity" are supported.
_EPILEPSY_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(k) for k in _EPILEPSY_KEYWORDS) + r")s?\b"
)


def preprocess_query(query):
    """Tokenize *query* and flag whether it mentions an epilepsy-related topic.

    Args:
        query: Free-text user query.

    Returns:
        ``(tokens, is_epilepsy_related)``. ``tokens`` is the lower-cased query
        split on whitespace, returned exactly as before (punctuation included).
        Keyword detection runs on a punctuation-stripped copy of the tokens, so
        "Epilepsy?" and "seizures." are recognized.
    """
    tokens = query.lower().split()
    # Strip leading/trailing punctuation for matching only; drop tokens that
    # were pure punctuation.
    cleaned = [t for t in (tok.strip(string.punctuation) for tok in tokens) if t]
    is_epilepsy_related = _EPILEPSY_PATTERN.search(" ".join(cleaned)) is not None
    return tokens, is_epilepsy_related
# Function to Generate Response
_LLM_MODEL = "llama-3.3-70b-versatile"
_GREETINGS = {"hello", "hi", "hey"}


def _ask_llm(prompt):
    """Send *prompt* as a single user message to the Groq LLaMA model; return the stripped reply."""
    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=_LLM_MODEL,
        stream=False,
    )
    return (completion.choices[0].message.content or "").strip()


def _correct_grammar(user_query):
    """Return an LLM grammar-corrected *user_query*, or the original text on any failure."""
    correction_prompt = (
        "Correct the following user query for grammar and spelling errors, "
        "but keep the original intent intact.\n"
        "Do not add or remove any information, just fix the grammar.\n"
        f"User Query: {user_query}\n"
        "Corrected Query:"
    )
    try:
        corrected = _ask_llm(correction_prompt)
    except Exception as e:
        print(f"⚠️ Grammar correction error: {e}")
        return user_query
    # The prompt ends with a "Corrected Query:" label; drop it if the model echoes it.
    label = "corrected query:"
    if corrected.lower().startswith(label):
        corrected = corrected[len(label):].strip()
    return corrected or user_query


def _is_greeting(tokens):
    """Return True only when every word of the query is a greeting (punctuation ignored).

    A greeting inside a real question ("Hi, I just had a seizure") must not
    short-circuit the answer.
    """
    words = [w for w in (t.strip(string.punctuation) for t in tokens) if w]
    return bool(words) and all(w in _GREETINGS for w in words)


def _pubmedbert_insights(query, success_message):
    """Embed *query* with PubMedBERT (mean over tokens), add it to the FAISS index, return a status line."""
    try:
        token_vectors = pubmedbert_pipeline(query)[0]
        embedding_mean = np.mean(token_vectors, axis=0)
        index.add(np.asarray([embedding_mean], dtype=np.float32))
    except Exception as e:
        return f"⚠️ Error during PubMedBERT analysis: {e}"
    return success_message


def generate_response(user_query):
    """Build NeuroGuard's Markdown reply to a user question.

    Pipeline: best-effort LLM grammar correction -> greeting short-circuit ->
    PubMedBERT embedding stored in the FAISS index -> LLaMA answer, using an
    epilepsy-specific or a general-health prompt.

    Args:
        user_query: Raw user text (e.g. a Whisper transcription).

    Returns:
        Markdown string. Failures in the embedding or LLM steps are reported
        inside the text instead of being raised.
    """
    corrected_query = _correct_grammar(user_query)
    tokens, is_epilepsy_related = preprocess_query(corrected_query)

    if _is_greeting(tokens):
        return "👋 Hello! How can I assist you today?"

    if is_epilepsy_related:
        pubmedbert_insights = _pubmedbert_insights(
            corrected_query,
            "**PubMedBERT Analysis:** ✅ Query is relevant to epilepsy research.",
        )
        prompt = (
            f"**User Query:** {corrected_query}\n"
            "**Instructions:** Provide a concise, structured, and human-friendly "
            "response specifically about epilepsy or seizures."
        )
    else:
        pubmedbert_insights = _pubmedbert_insights(
            corrected_query,
            "**PubMedBERT Analysis:** PubMed analysis performed for health-related context.",
        )
        prompt = (
            f"**User Query:** {corrected_query}\n"
            "**Instructions:** Provide a concise, structured, and human-friendly "
            "response to the general health query. If the query is clearly not "
            "health-related, respond generally."
        )

    try:
        model_response = _ask_llm(prompt)
    except Exception as e:
        model_response = f"⚠️ Error generating response with LLaMA: {e}"
    return f"**NeuroGuard:** ✅ **Analysis:**\n{pubmedbert_insights}\n\n**Response:**\n{model_response}"
# Streamlit UI Setup
st.set_page_config(page_title="NeuroGuard: Voice Call Health Assistant", layout="wide")
st.title("📞 NeuroGuard: Voice Call Health Assistant")
st.write("🎙️ Click 'Start Recording', speak your question, and NeuroGuard will respond.")
bot_response_area = st.empty()  # For bot response display

# NOTE(review): sounddevice records from the default input device of the machine
# running this script (the Streamlit server), not from the visitor's browser.
if st.button("Start Recording"):
    st.write("Recording... Speak now.")
    fs = 44100  # sample rate (Hz)
    seconds = 10  # fixed recording length
    temp_audio_path = None
    try:
        myrecording = sd.rec(int(seconds * fs), samplerate=fs, channels=1)
        sd.wait()  # block until the recording is finished
        # Only reserve a path here, and close our handle before soundfile writes
        # by name: re-opening a still-open NamedTemporaryFile fails on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_audio_path = temp_audio.name
        sf.write(temp_audio_path, myrecording, fs)
        with open(temp_audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        st.audio(audio_bytes, format="audio/wav")

        with st.spinner("Transcribing audio..."):
            transcription = whisper_model.transcribe(temp_audio_path)
        user_query = transcription["text"].strip()

        if not user_query:
            # Nothing intelligible was said; don't send an empty prompt to the LLM.
            st.warning("⚠️ No speech detected. Please try again.")
        else:
            st.write(f"**You said:** {user_query}")
            with bot_response_area.container():
                with st.chat_message("bot"):
                    with st.spinner("🤖 Thinking..."):
                        try:
                            response = generate_response(user_query)
                            st.markdown(response)
                        except Exception as e:
                            st.error(f"⚠️ Error processing query: {e}")
    except Exception as e:
        st.error(f"⚠️ Recording error: {e}")
    finally:
        # delete=False above means this code owns the temp file's cleanup.
        if temp_audio_path:
            try:
                os.remove(temp_audio_path)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless