"""MindSpark AI Psychiatric Assistant (Streamlit app).

For each user message the app retrieves the closest follow-up question from
``symptom_questions.csv`` via FAISS search, asks a Groq-hosted LLM to rephrase
it empathetically, and appends both turns to the session chat history. It also
offers a LongT5 chat summary, disorder matching against
``treatment_recommendations.csv``, and the matching treatment text.
"""

import logging
import os

# ✅ Set cache directory.
# Set these BEFORE importing the Hugging Face libraries below: some of them
# read these variables when they are first imported.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; kept for compatibility with older versions.
CACHE_DIR = "/tmp/huggingface"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["SENTENCE_TRANSFORMERS_HOME"] = CACHE_DIR

import faiss  # noqa: E402
import pandas as pd  # noqa: E402
import streamlit as st  # noqa: E402
from groq import Groq  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # noqa: E402

logger = logging.getLogger(__name__)

# Maximum number of candidate disorders returned by detect_disorders().
TOP_K_DISORDERS = 3

# ✅ Securely Fetch API Key
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)


@st.cache_resource
def _load_models():
    """Load the embedding and summarization models once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, every click would reload all models and the tokenizer.

    Returns:
        (similarity_model, embedding_model, summarization_model,
        summarization_tokenizer)
    """
    similarity = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2", cache_folder=CACHE_DIR
    )
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
    summ_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=CACHE_DIR
    )
    summ_tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=CACHE_DIR
    )
    return similarity, embedding, summ_model, summ_tokenizer


@st.cache_data
def _load_datasets():
    """Read the recommendations and questions CSVs.

    Raises:
        FileNotFoundError: if either CSV is missing. Streamlit does not cache
            a raised exception, so adding the file and rerunning recovers.
    """
    return (
        pd.read_csv("treatment_recommendations.csv"),
        pd.read_csv("symptom_questions.csv"),
    )


@st.cache_resource
def _build_indices(_similarity_model, _embedding_model, disorders, questions):
    """Build the FAISS indices for disorder detection and question retrieval.

    Underscore-prefixed parameters are excluded from Streamlit's cache key;
    the cache is keyed on the hashable ``disorders`` / ``questions`` tuples.

    Returns:
        (disorder_index, question_index)
    """
    # Inner product equals cosine similarity only for unit-length vectors,
    # so normalize here and at query time in detect_disorders().
    disorder_emb = _similarity_model.encode(
        list(disorders), convert_to_numpy=True, normalize_embeddings=True
    )
    disorder_index = faiss.IndexFlatIP(disorder_emb.shape[1])
    disorder_index.add(disorder_emb)

    question_emb = _embedding_model.encode(list(questions), convert_to_numpy=True)
    q_index = faiss.IndexFlatL2(question_emb.shape[1])
    q_index.add(question_emb)
    return disorder_index, q_index


# ✅ Load datasets (before the models, so a missing file fails fast)
try:
    recommendations_df, questions_df = _load_datasets()
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()

# ✅ Load AI Models (cached across reruns)
(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_models()

# ✅ FAISS indices: `index` for disorder detection, `question_index` for retrieval
index, question_index = _build_indices(
    similarity_model,
    embedding_model,
    tuple(recommendations_df["Disorder"].tolist()),
    tuple(questions_df["Questions"].tolist()),
)


# ✅ Function: Retrieve the most relevant question
def retrieve_questions(user_input):
    """Return the stored question whose row is nearest to *user_input*.

    A ``Questions`` cell may hold several questions separated by ", "; only
    the first one is returned.
    NOTE(review): a question that itself contains ", " is cut at that comma
    as well; confirm the CSV never does this.
    """
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)
    best = indices[0][0]
    if best == -1:  # FAISS returns -1 when there is no result to give
        return "I'm sorry, I couldn't find a relevant question."
    question_block = questions_df["Questions"].iloc[best]
    # str.split yields [question_block] when the separator is absent.
    return question_block.split(", ")[0]


# ✅ Function: Generate empathetic response using Groq API
def generate_empathetic_response(user_input, retrieved_question):
    """Ask the Groq LLM to rephrase *retrieved_question* empathetically.

    Returns the model's reply. If the API call fails or returns no content,
    returns a fixed apology string; failures are logged, not discarded.
    """
    prompt = f"""
    The user said: "{user_input}"
    Relevant Question:
    - {retrieved_question}
    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    """
    fallback = "I'm sorry, I couldn't process your request."
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
    except Exception:
        # Boundary with an external service: degrade gracefully, keep a trace.
        logger.exception("Groq chat completion failed")
        return fallback
    return response.choices[0].message.content or fallback


# ✅ Function to detect disorders
def detect_disorders(chat_history):
    """Detect psychiatric disorders from full chat history.

    Embeds the concatenated chat and returns up to ``TOP_K_DISORDERS``
    distinct ``Disorder`` names from ``recommendations_df``, ranked by cosine
    similarity. On empty input, no match, or an unexpected error, returns a
    one-element list holding a human-readable message instead.
    """
    if not chat_history:  # ✅ Handle empty chat history
        return ["No input provided."]

    full_chat_text = " ".join(chat_history).strip()
    if not full_chat_text:  # ✅ Handle case where all messages are empty strings
        return ["No meaningful text provided."]

    # NOTE(review): chat_history holds both "User: ..." and "AI: ..." lines,
    # so the AI's own questions also influence the match; confirm intended.
    k = min(TOP_K_DISORDERS, index.ntotal)
    if k == 0:
        return ["No matching disorder found."]

    try:
        text_embedding = similarity_model.encode(
            [full_chat_text], convert_to_numpy=True, normalize_embeddings=True
        )
        _, indices = index.search(text_embedding, k)
    except Exception as e:  # ✅ Catch unexpected errors
        logger.exception("Disorder detection failed")
        return [f"Error detecting disorders: {e}"]

    # Drop -1 padding (iloc[-1] would silently return the last row) and
    # de-duplicate names while keeping rank order.
    disorders = list(
        dict.fromkeys(
            recommendations_df["Disorder"].iloc[i] for i in indices[0] if i != -1
        )
    )
    return disorders or ["No matching disorder found."]


# ✅ Function to get treatment recommendations
def get_treatment(detected_disorders):
    """Retrieve treatment recommendations based on detected disorders.

    Returns a dict mapping each entry of *detected_disorders* that appears in
    ``recommendations_df["Disorder"]`` to the ``Treatment Recommendation`` of
    its first matching row. Other entries (e.g. the status messages returned
    by detect_disorders) are skipped.
    """
    # One lookup table instead of re-filtering the DataFrame per disorder;
    # drop_duplicates keeps the first row for each disorder.
    lookup = recommendations_df.drop_duplicates("Disorder").set_index("Disorder")[
        "Treatment Recommendation"
    ]
    return {
        disorder: lookup[disorder]
        for disorder in detected_disorders
        if disorder in lookup.index
    }


# ✅ Streamlit UI Setup
st.title("🧠 MindSpark AI Psychiatric Assistant")

# Initialize once per session so every handler below can read the history
# without a KeyError, even before the first "Ask AI" click.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
chat_history = st.session_state["chat_history"]

user_input = st.text_input("Enter your message:")

if st.button("Ask AI") and user_input.strip():
    retrieved_question = retrieve_questions(user_input)
    empathetic_response = generate_empathetic_response(user_input, retrieved_question)
    # chat_history is the list stored in session state, so appending persists it.
    chat_history.append(f"User: {user_input}")
    chat_history.append(f"AI: {empathetic_response}")

st.subheader("Chat History")
for msg in chat_history:
    st.write(msg)

if st.button("Summarize Chat"):
    if chat_history:
        chat_text = " ".join(chat_history)
        inputs = summarization_tokenizer(
            "summarize: " + chat_text,
            return_tensors="pt",
            max_length=4096,
            truncation=True,
        )
        # **inputs forwards the attention mask along with the input ids.
        summary_ids = summarization_model.generate(
            **inputs, max_length=500, num_beams=4, early_stopping=True
        )
        summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        st.subheader("Chat Summary")
        st.write(summary)
    else:
        st.error("❌ Please enter chat history.")

if st.button("Detect Disorders"):
    if chat_history:
        disorders = detect_disorders(chat_history)
        st.subheader("Detected Disorders:")
        for disorder in disorders:
            st.write(f"- {disorder}")
    else:
        st.error("❌ Please enter chat history.")

if st.button("Get Treatment Recommendations"):
    if chat_history:
        detected_disorders = detect_disorders(chat_history)
        treatments = get_treatment(detected_disorders)
        st.subheader("Treatment Recommendations:")
        if not treatments:
            st.info("No treatment recommendations found.")
        for disorder, treatment in treatments.items():
            st.write(f"**{disorder}:** {treatment}")
    else:
        st.error("❌ Please enter chat history.")