File size: 2,190 Bytes
a2a3f39
 
 
ee3b51e
 
 
 
 
 
 
 
 
a2a3f39
 
 
 
ee3b51e
a2a3f39
 
 
 
ee3b51e
a2a3f39
 
ee3b51e
a2a3f39
 
 
 
ee3b51e
a2a3f39
 
 
 
 
 
ee3b51e
a2a3f39
 
 
 
ee3b51e
a2a3f39
 
ee3b51e
a2a3f39
 
 
 
ee3b51e
a2a3f39
 
 
 
 
 
 
 
 
ee3b51e
a2a3f39
 
ee3b51e
a2a3f39
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import importlib
import os
import subprocess
import sys

# Set up environment variables: point every Hugging Face cache at
# /tmp/huggingface. This must happen *before* transformers /
# sentence_transformers are imported, because huggingface_hub and transformers
# read these variables when they are first imported; assignments made after
# the imports are ignored by them.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

import pandas as pd
import streamlit as st

# Ensure FAISS is installed. Install into the interpreter that is running this
# app (``sys.executable -m pip``) rather than whichever ``pip`` is first on
# PATH, and fail loudly (CalledProcessError) if the install does not succeed.
try:
    import faiss
except ImportError:
    subprocess.run([sys.executable, "-m", "pip", "install", "faiss-cpu"], check=True)
    # The package appeared after interpreter start-up; drop stale import-finder
    # caches so the retry below can see it.
    importlib.invalidate_caches()
    import faiss

from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq

# Load API Key: the Groq client is built from the GROQ_API_KEY environment
# variable. If it is unset or empty, report that in the UI and stop this run
# before a client is constructed.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    st.error("GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)

# Load AI Models
st.sidebar.header("Loading AI Models... Please Wait ⏳")

_HF_CACHE_DIR = "/tmp/huggingface"
_SUMMARIZATION_CHECKPOINT = "google/long-t5-tglobal-base"


@st.cache_resource
def _load_models():
    """Load the app's Hugging Face models once and reuse them across reruns.

    Streamlit re-executes this script on every widget interaction; without
    ``st.cache_resource`` all four objects would be re-created each time.

    Returns:
        Tuple of (similarity model, embedding model, summarization model,
        summarization tokenizer), downloaded/cached under ``_HF_CACHE_DIR``.
    """
    similarity = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2", cache_folder=_HF_CACHE_DIR
    )
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=_HF_CACHE_DIR)
    summarizer = AutoModelForSeq2SeqLM.from_pretrained(
        _SUMMARIZATION_CHECKPOINT, cache_dir=_HF_CACHE_DIR
    )
    tokenizer = AutoTokenizer.from_pretrained(
        _SUMMARIZATION_CHECKPOINT, cache_dir=_HF_CACHE_DIR
    )
    return similarity, embedding, summarizer, tokenizer


(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_models()

# Load Datasets
# Both CSVs are read via paths relative to the working directory. A missing or
# empty file, or a recommendations table with no rows or no "Disorder" column
# (the column the disorder index is built from), is reported in the UI and
# halts the run instead of surfacing as a raw traceback.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"Missing dataset file: {e}")
    st.stop()
except pd.errors.EmptyDataError as e:
    st.error(f"Dataset file is empty: {e}")
    st.stop()

if "Disorder" not in recommendations_df.columns or recommendations_df.empty:
    st.error(
        "treatment_recommendations.csv must contain a 'Disorder' column "
        "and at least one row."
    )
    st.stop()

# FAISS Index for Disorders
@st.cache_resource
def _build_disorder_index(_model, disorders):
    """Embed each disorder name and add the vectors to a flat FAISS index.

    ``_model`` has a leading underscore so Streamlit does not try to hash it;
    the cache is keyed on the ``disorders`` tuple alone, so the index is only
    rebuilt when the disorder list changes rather than on every rerun.

    Embeddings are L2-normalized because ``IndexFlatIP`` scores by raw inner
    product, which equals cosine similarity only for unit-length vectors.
    Vectors are added in order, so index id i corresponds to ``disorders[i]``.

    Returns:
        (embeddings array, populated ``faiss.IndexFlatIP``).
    """
    embeddings = _model.encode(
        list(disorders), convert_to_numpy=True, normalize_embeddings=True
    )
    flat_index = faiss.IndexFlatIP(embeddings.shape[1])
    flat_index.add(embeddings)
    return embeddings, flat_index


treatment_embeddings, index = _build_disorder_index(
    similarity_model, tuple(recommendations_df["Disorder"].tolist())
)

# UI - Streamlit Chatbot
st.title("MindSpark AI Psychiatrist 💬")

# Number of history entries displayed. Each send appends two entries (user +
# AI), so 6 entries shows the last 3 exchanges.
_HISTORY_WINDOW = 6

# st.session_state persists across reruns, so the history accumulates.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_input("You:", "")
message = user_input.strip()
# st.button must be evaluated first so the widget is always rendered; empty or
# whitespace-only input is ignored.
if st.button("Send") and message:
    st.session_state.chat_history.append(f"User: {message}")
    # TODO: placeholder reply — no model is queried here yet.
    st.session_state.chat_history.append("AI: [Response]")

st.write("### Chat History")
for msg in st.session_state.chat_history[-_HISTORY_WINDOW:]:
    st.text(msg)