# Scraped from a Hugging Face Space page (page status shown: "Sleeping").
import importlib
import os
import subprocess
import sys

import pandas as pd
import streamlit as st
# Ensure FAISS is installed.
# Hosted environments may lack the FAISS wheel, so fall back to installing it
# at startup. `sys.executable -m pip` installs into the interpreter that is
# actually running this app (a bare `pip` on PATH may belong to another
# Python), and check=True makes a failed install raise immediately instead of
# surfacing later as a confusing second ImportError.
try:
    import faiss
except ImportError:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "faiss-cpu"],
        check=True,
    )
    # Path finders cache directory listings; without this the freshly
    # installed package may not be visible to the import below.
    importlib.invalidate_caches()
    import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq
# Set up environment variables: point every Hugging Face cache location at a
# writable scratch directory.
# NOTE(review): transformers / sentence_transformers are imported earlier in
# this file, and some of these variables are read at import time; the explicit
# cache_dir / cache_folder arguments passed when loading models are what
# reliably take effect — confirm before relying on these alone.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface",
        "SENTENCE_TRANSFORMERS_HOME": "/tmp/huggingface",
    }
)
# Load API Key.
# The Groq client needs a key: when GROQ_API_KEY is unset or empty, the error
# is shown in the UI and st.stop() ends the script run before the client is
# constructed.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    st.error("GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()
client = Groq(api_key=GROQ_API_KEY)
# Load AI Models.
# Streamlit re-executes this whole script on every widget interaction, so the
# models are loaded once per server process through st.cache_resource rather
# than on every rerun. The spinner text is shown only while loading is in
# progress (the previous sidebar header stayed on screen permanently).
@st.cache_resource(show_spinner="Loading AI Models... Please Wait ⏳")
def _load_models():
    """Load the embedding and summarization models into the local HF cache.

    Returns:
        A tuple ``(similarity_model, embedding_model, summarization_model,
        summarization_tokenizer)``.
    """
    cache_dir = "/tmp/huggingface"
    similarity = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2", cache_folder=cache_dir
    )
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
    summarizer = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=cache_dir
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=cache_dir
    )
    return similarity, embedding, summarizer, tokenizer


(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_models()
# Load Datasets.
# Both CSV files are opened relative to the current working directory. If
# either one is absent, the app reports which file is missing and halts the
# run via st.stop() instead of showing a traceback.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as missing_file:
    st.error("Missing dataset file: " + str(missing_file))
    st.stop()
# FAISS Index for Disorders.
# Builds an exact inner-product index over the embedded "Disorder" names.
@st.cache_resource(show_spinner=False)
def _build_disorder_index(disorders):
    """Embed *disorders* and return ``(embeddings, faiss.IndexFlatIP)``.

    Args:
        disorders: Tuple of disorder names. A tuple is used so that Streamlit
            can hash it as the cache key.

    Returns:
        The embedding matrix and an ``IndexFlatIP`` containing it. Embeddings
        are L2-normalized, so inner-product scores are cosine similarities.
    """
    embeddings = similarity_model.encode(
        list(disorders), convert_to_numpy=True, normalize_embeddings=True
    )
    disorder_index = faiss.IndexFlatIP(embeddings.shape[1])
    disorder_index.add(embeddings)
    return embeddings, disorder_index


# The result is cached so that Streamlit does not re-encode every disorder on
# each rerun.
treatment_embeddings, index = _build_disorder_index(
    tuple(recommendations_df["Disorder"].tolist())
)
# UI - Streamlit Chatbot
st.title("MindSpark AI Psychiatrist 💬")

# Streamlit reruns the script on every interaction; session_state keeps the
# transcript alive across those reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_input("You:", "")
if st.button("Send"):
    message = user_input.strip()
    # Ignore empty or whitespace-only submissions.
    if message:
        st.session_state.chat_history.append(f"User: {message}")
        # TODO: generate a real reply (e.g. with `client` and the loaded
        # models); for now a fixed placeholder is recorded.
        st.session_state.chat_history.append("AI: [Response]")

st.write("### Chat History")
# Each send appends a User/AI pair, so the last six entries are the three
# most recent exchanges.
for msg in st.session_state.chat_history[-6:]:
    st.text(msg)