"""MindSpark AI Psychiatrist: a Streamlit chat front end.

On each run the script:
  1. points the Hugging Face caches at /tmp/huggingface,
  2. requires GROQ_API_KEY and creates a Groq client,
  3. loads two sentence-embedding models and a LongT5 model/tokenizer,
  4. reads the treatment and symptom CSV files,
  5. builds an inner-product FAISS index over the "Disorder" column,
  6. renders a chat box that records the user's message and a placeholder
     "AI: [Response]" reply, then shows the last six history entries.
"""
import os
import subprocess
import sys

HF_CACHE_DIR = "/tmp/huggingface"

# These must be exported BEFORE sentence_transformers / transformers are
# imported: the Hugging Face libraries resolve their cache locations when they
# are first imported, so assigning the variables afterwards can be ignored.
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_CACHE_DIR

import pandas as pd
import streamlit as st

# Ensure FAISS is installed.
try:
    import faiss
except ImportError:
    # Install through the running interpreter so the package lands in the same
    # environment the app uses; check=True surfaces a failed install instead
    # of falling through to a second, confusing ImportError.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "faiss-cpu"],
        check=True,
    )
    import faiss

from groq import Groq
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


@st.cache_resource
def _load_models():
    """Load every model once per server process.

    Streamlit re-executes the script on each interaction; caching keeps the
    models from being reloaded on every button press.

    Returns:
        Tuple of (similarity model, embedding model, summarization model,
        summarization tokenizer).
    """
    similarity = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2", cache_folder=HF_CACHE_DIR
    )
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_CACHE_DIR)
    summarizer = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=HF_CACHE_DIR
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=HF_CACHE_DIR
    )
    return similarity, embedding, summarizer, tokenizer


@st.cache_resource
def _build_disorder_index(_model, disorders):
    """Embed the disorder names and index them for inner-product search.

    Args:
        _model: SentenceTransformer used for encoding. The leading underscore
            tells Streamlit not to hash this argument for the cache key.
        disorders: List of disorder names; this is the cache key.

    Returns:
        Tuple of (embedding matrix, faiss.IndexFlatIP holding those rows).
    """
    vectors = _model.encode(disorders, convert_to_numpy=True)
    flat_index = faiss.IndexFlatIP(vectors.shape[1])
    flat_index.add(vectors)
    return vectors, flat_index


# Load API key; the app cannot do anything useful without it.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)

# Load AI models (cached after the first run).
st.sidebar.header("Loading AI Models... Please Wait ⏳")
(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_models()

# Load datasets.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"Missing dataset file: {e}")
    st.stop()

# Without a populated "Disorder" column there is nothing to embed, and the
# index construction below would fail with an opaque KeyError/IndexError.
if "Disorder" not in recommendations_df.columns or recommendations_df.empty:
    st.error("treatment_recommendations.csv must contain a non-empty 'Disorder' column.")
    st.stop()

# FAISS index for disorders.
treatment_embeddings, index = _build_disorder_index(
    similarity_model, recommendations_df["Disorder"].tolist()
)

# UI - Streamlit chatbot.
st.title("MindSpark AI Psychiatrist 💬")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_input("You:", "")

if st.button("Send"):
    if user_input:
        st.session_state.chat_history.append(f"User: {user_input}")
        # Placeholder reply; no model is invoked here.
        st.session_state.chat_history.append("AI: [Response]")

st.write("### Chat History")
# Only the six most recent entries are displayed.
for msg in st.session_state.chat_history[-6:]:
    st.text(msg)