# Author: mindspark121 — commit ee3b51e ("Update app.py", verified)
# Copied from the Hugging Face file viewer (raw · history · blame); original size 2.19 kB.
import os
import streamlit as st
import pandas as pd
import subprocess
# Set up environment variables.
# Point all Hugging Face caches at a writable directory. This is done BEFORE
# importing transformers / sentence_transformers because Hugging Face
# libraries typically read these variables when they are first imported;
# setting them afterwards may have no effect.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

# Ensure FAISS is installed (fallback for environments where faiss-cpu was
# not preinstalled).
try:
    import faiss
except ImportError:
    import importlib
    import sys

    # Install with the running interpreter's pip so the package lands in the
    # same environment this process imports from; raise if pip fails instead
    # of continuing to a confusing second ImportError.
    subprocess.run([sys.executable, "-m", "pip", "install", "faiss-cpu"], check=True)
    # The import system caches directory listings; refresh them so the newly
    # installed package can be found by this already-running process.
    importlib.invalidate_caches()
    import faiss

from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq
# Load API Key.
# The Groq client requires a non-empty GROQ_API_KEY; if it is unset or empty,
# show an error in the UI and halt this Streamlit run.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY in (None, ""):
    st.error("GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()
client = Groq(api_key=GROQ_API_KEY)
# Load AI Models
st.sidebar.header("Loading AI Models... Please Wait ⏳")


@st.cache_resource
def _load_ai_models():
    """Load the embedding and summarization models once per server process.

    Streamlit re-executes the whole script on every widget interaction.
    Caching the loaded models as a shared resource avoids re-instantiating
    them on each rerun (e.g. every click of "Send").

    Returns:
        Tuple of (similarity_model, embedding_model, summarization_model,
        summarization_tokenizer).
    """
    cache_dir = "/tmp/huggingface"
    summarization_name = "google/long-t5-tglobal-base"
    return (
        SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder=cache_dir),
        SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir),
        AutoModelForSeq2SeqLM.from_pretrained(summarization_name, cache_dir=cache_dir),
        AutoTokenizer.from_pretrained(summarization_name, cache_dir=cache_dir),
    )


(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_ai_models()
# Load Datasets


@st.cache_data
def _load_datasets():
    """Read the treatment-recommendation and symptom-question CSV files.

    Cached so the files are parsed once instead of on every Streamlit rerun.
    A raised exception (e.g. FileNotFoundError) is not cached, so a missing
    file is looked for again on the next run.

    Returns:
        Tuple of (recommendations_df, questions_df).
    """
    return (
        pd.read_csv("treatment_recommendations.csv"),
        pd.read_csv("symptom_questions.csv"),
    )


try:
    recommendations_df, questions_df = _load_datasets()
except FileNotFoundError as e:
    st.error(f"Missing dataset file: {e}")
    st.stop()
# FAISS Index for Disorders


@st.cache_resource
def _build_disorder_index(_model, disorders):
    """Embed disorder names and load them into an inner-product FAISS index.

    Cached so the (expensive) encoding runs once per distinct disorder list
    rather than on every Streamlit rerun.

    Args:
        _model: Sentence-transformer used for encoding. The leading
            underscore tells Streamlit not to hash it for the cache key.
        disorders: Disorder names; index id i corresponds to disorders[i].

    Returns:
        Tuple of (embeddings, index).
    """
    # Normalize explicitly so the inner product used by IndexFlatIP equals
    # cosine similarity, independent of the model's own output config.
    embeddings = _model.encode(disorders, convert_to_numpy=True, normalize_embeddings=True)
    disorder_index = faiss.IndexFlatIP(embeddings.shape[1])
    disorder_index.add(embeddings)
    return embeddings, disorder_index


# Without at least one disorder the embedding matrix has no second dimension,
# and a missing column would raise KeyError; fail with a readable message.
if "Disorder" not in recommendations_df.columns or recommendations_df.empty:
    st.error("treatment_recommendations.csv must contain a non-empty 'Disorder' column.")
    st.stop()

# fillna/astype(str) keep exactly one entry per row, so FAISS ids line up with
# DataFrame positions, while keeping NaN cells from reaching encode().
treatment_embeddings, index = _build_disorder_index(
    similarity_model,
    recommendations_df["Disorder"].fillna("").astype(str).tolist(),
)
# UI - Streamlit Chatbot
# The emoji was previously mis-encoded as "πŸ’¬" (UTF-8 bytes of 💬 decoded
# with the wrong codec).
st.title("MindSpark AI Psychiatrist 💬")

# Number of most recent chat lines to display (each exchange adds two lines).
_VISIBLE_HISTORY = 6

# Session state persists across Streamlit reruns within one browser session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_input("You:", "")
if st.button("Send"):
    message = user_input.strip()
    # Ignore empty or whitespace-only submissions.
    if message:
        st.session_state.chat_history.append(f"User: {message}")
        # Placeholder reply: no model response is generated here yet.
        st.session_state.chat_history.append("AI: [Response]")

st.write("### Chat History")
for msg in st.session_state.chat_history[-_VISIBLE_HISTORY:]:
    st.text(msg)