File size: 4,157 Bytes
a2a3f39
 
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
 
 
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
e665e5f
a2a3f39
 
 
 
e665e5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2a3f39
e665e5f
 
 
 
 
 
a2a3f39
e665e5f
 
 
a2a3f39
e665e5f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import logging
import os

import faiss
import pandas as pd
import streamlit as st
from groq import Groq
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face cache location.
# NOTE(review): transformers / sentence_transformers are imported above this
# point. If those libraries read these variables at import time, these
# assignments will not change their default cache paths. The model loaders
# below pass an explicit cache directory anyway. Move these lines above the
# third-party imports if the env vars must take effect.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

# Read the Groq API key from the environment (never hard-code it).
# Surrounding whitespace is stripped: a stray trailing newline would otherwise
# be sent as part of the key, and a whitespace-only value counts as missing.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "").strip()
if not GROQ_API_KEY:
    st.error("❌ GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)

# Load AI models.
# Streamlit re-executes this whole script on every widget interaction. Without
# caching, all three models would be reloaded on every button click.
# st.cache_resource keeps the loaded objects and returns the same instances on
# later reruns.
_MODEL_CACHE_DIR = "/tmp/huggingface"


@st.cache_resource(show_spinner="Loading AI models...")
def _load_models():
    """Load the two sentence encoders and the LongT5 summarizer.

    Returns:
        Tuple ``(similarity_model, embedding_model, summarization_model,
        summarization_tokenizer)``.
    """
    similarity = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2", cache_folder=_MODEL_CACHE_DIR
    )
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=_MODEL_CACHE_DIR)
    summarizer = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=_MODEL_CACHE_DIR
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=_MODEL_CACHE_DIR
    )
    return similarity, embedding, summarizer, tokenizer


# Keep the original module-level names so the rest of the script is unchanged.
(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_models()

# Load datasets.
def _require_text_column(df, column, source_name):
    """Return *df* restricted to rows that have a value in *column*.

    Blank CSV cells are read by pandas as NaN. Those rows are dropped because
    the sentence encoders are given this column to embed and presumably
    cannot embed non-string values. The index is renumbered afterwards.
    Stops the Streamlit run with an error if the column is missing or no
    usable rows remain. Without this check, an empty index would later fail
    with an IndexError on ``embeddings.shape[1]``.

    Args:
        df: DataFrame as returned by ``pd.read_csv``.
        column: Name of the text column the app embeds.
        source_name: File name shown in the error message.

    Returns:
        The cleaned DataFrame.
    """
    if column not in df.columns:
        st.error(f"❌ {source_name} has no '{column}' column.")
        st.stop()
    cleaned = df.dropna(subset=[column]).reset_index(drop=True)
    if cleaned.empty:
        st.error(f"❌ {source_name} has no rows with a '{column}' value.")
        st.stop()
    return cleaned


try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()
except pd.errors.EmptyDataError as e:
    # read_csv raises this when a file has no columns to parse (e.g. zero bytes).
    st.error(f"❌ Dataset file is empty: {e}")
    st.stop()

recommendations_df = _require_text_column(
    recommendations_df, "Disorder", "treatment_recommendations.csv"
)
questions_df = _require_text_column(questions_df, "Questions", "symptom_questions.csv")

# FAISS indexes for disorder detection and question retrieval.
# Encoding every row is the expensive step, and Streamlit re-runs this script
# on each interaction. The vectors and indexes are therefore cached with
# st.cache_resource, keyed on the row texts, so changed data yields a fresh build.


@st.cache_resource(show_spinner="Building search indexes...")
def _build_indexes(disorders, questions):
    """Embed disorder names and questions and build one FAISS index for each.

    Args:
        disorders: Tuple of disorder names (a tuple, so Streamlit can hash it
            as part of the cache key).
        questions: Tuple of "Questions" cells, same reason.

    Returns:
        Tuple ``(treatment_embeddings, index, question_embeddings,
        question_index)``.
    """
    # Disorder detection: inner-product index over the mpnet embeddings.
    # NOTE(review): inner product equals cosine similarity only for unit-length
    # vectors. Confirm the model's output is normalized.
    disorder_vectors = similarity_model.encode(list(disorders), convert_to_numpy=True)
    disorder_index = faiss.IndexFlatIP(disorder_vectors.shape[1])
    disorder_index.add(disorder_vectors)

    # Question retrieval: exact L2 nearest-neighbour index over MiniLM embeddings.
    question_vectors = embedding_model.encode(list(questions), convert_to_numpy=True)
    questions_faiss = faiss.IndexFlatL2(question_vectors.shape[1])
    questions_faiss.add(question_vectors)

    return disorder_vectors, disorder_index, question_vectors, questions_faiss


# Keep the original module-level names so the rest of the script is unchanged.
treatment_embeddings, index, question_embeddings, question_index = _build_indexes(
    tuple(recommendations_df["Disorder"].tolist()),
    tuple(questions_df["Questions"].tolist()),
)

# Retrieve the most relevant question for a user message.
def retrieve_questions(user_input):
    """Return the stored question nearest to *user_input*.

    The message is embedded with ``embedding_model`` and looked up in
    ``question_index`` with k=1. A "Questions" cell may hold several
    questions separated by ", ". Only the first one is returned.

    Args:
        user_input: Free-text message typed by the user.

    Returns:
        The first question of the best-matching row, or an apology string when
        the search reports no neighbour (label -1).
    """
    query_vector = embedding_model.encode([user_input], convert_to_numpy=True)
    _distances, neighbour_ids = question_index.search(query_vector, 1)
    best_row = neighbour_ids[0][0]
    if best_row == -1:
        return "I'm sorry, I couldn't find a relevant question."

    question_block = questions_df["Questions"].iloc[best_row]
    # A cell without the separator is returned as-is.
    if ", " not in question_block:
        return question_block
    return question_block.split(", ", 1)[0]

# Generate an empathetic response using the Groq API.
def generate_empathetic_response(user_input, retrieved_question):
    """Ask the Groq-hosted LLM to rephrase *retrieved_question* empathetically.

    Args:
        user_input: The user's latest message, included in the prompt as context.
        retrieved_question: The question selected by ``retrieve_questions``.

    Returns:
        The model's reply text. Returns a fixed apology string if the API call
        fails or the reply has no text. Failures are logged rather than raised
        so the chat UI keeps working.
    """
    fallback = "I'm sorry, I couldn't process your request."
    # Built line by line so the prompt carries no source-code indentation.
    prompt = "\n".join(
        (
            f'The user said: "{user_input}"',
            f"Relevant Question: - {retrieved_question}",
            "You are an empathetic AI psychiatrist. "
            "Rephrase this question naturally in a human-like way.",
        )
    )
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        content = response.choices[0].message.content
    except Exception:
        # UI boundary: any client, network or response-shape error degrades
        # to the fallback. It is logged with traceback instead of silently
        # swallowed.
        logging.getLogger(__name__).exception("Groq chat completion failed")
        return fallback
    # Never show "None" or an empty bubble to the user.
    if not content or not content.strip():
        return fallback
    return content.strip()

# Streamlit UI.
st.title("🧠 MindSpark AI Psychiatric Assistant")

# Conversation turns live in session_state so they survive Streamlit reruns.
# The list is created once. Later code appends to that same stored list.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
chat_history = st.session_state["chat_history"]

user_input = st.text_input("Enter your message:")

# Whitespace-only input would only waste an embedding lookup and an API call.
if st.button("Ask AI") and user_input.strip():
    with st.spinner("Thinking..."):
        retrieved_question = retrieve_questions(user_input)
        empathetic_response = generate_empathetic_response(user_input, retrieved_question)
    chat_history.append(f"User: {user_input}")
    chat_history.append(f"AI: {empathetic_response}")

st.subheader("Chat History")
for msg in chat_history:
    st.write(msg)

if st.button("Summarize Chat"):
    if not chat_history:
        # Nothing to feed the summarizer; skip the (slow) model call entirely.
        st.info("Nothing to summarize yet - start a conversation first.")
    else:
        # One turn per line keeps speaker boundaries visible to the model.
        chat_text = "\n".join(chat_history)
        # NOTE(review): google/long-t5-tglobal-base may be a pretrained rather
        # than summarization-fine-tuned checkpoint. Confirm the output quality.
        with st.spinner("Summarizing..."):
            inputs = summarization_tokenizer(
                "summarize: " + chat_text,
                return_tensors="pt",
                max_length=4096,
                truncation=True,
            )
            # **inputs forwards the attention mask together with input_ids.
            summary_ids = summarization_model.generate(
                **inputs, max_length=500, num_beams=4, early_stopping=True
            )
            summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        st.subheader("Chat Summary")
        st.write(summary)