File size: 5,806 Bytes
a2a3f39
 
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
 
 
 
e665e5f
a2a3f39
 
 
 
e665e5f
a2a3f39
 
e665e5f
a2a3f39
 
 
 
e665e5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9adb77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e665e5f
 
 
 
 
 
a2a3f39
e665e5f
 
 
 
 
 
a2a3f39
e665e5f
 
 
a2a3f39
e665e5f
 
 
 
 
 
 
d9adb77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import streamlit as st
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq

# ✅ Point every Hugging Face cache location at a writable temp directory.
# NOTE(review): huggingface_hub / transformers / sentence-transformers appear to
# read these variables when they are imported, and the imports above run first,
# so these assignments may not change the default cache paths. The model
# loaders below pass the cache path explicitly, which is what reliably applies.
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "SENTENCE_TRANSFORMERS_HOME"):
    os.environ[_cache_var] = "/tmp/huggingface"

# ✅ Read the Groq credential from the environment; the app halts if it is absent
# or empty, showing an error in the Streamlit page.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

# Shared Groq client used for chat completions (see generate_empathetic_response).
client = Groq(api_key=GROQ_API_KEY)

# ✅ Load AI models once per server process.
# Streamlit re-executes this whole script on every widget interaction; without
# st.cache_resource all four checkpoints would be reloaded on each button click.
@st.cache_resource
def _load_models():
    """Load the similarity, embedding and summarisation models.

    Returns:
        Tuple ``(similarity_model, embedding_model, summarization_model,
        summarization_tokenizer)``. Weights are cached under
        ``/tmp/huggingface``.
    """
    similarity = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
    summarizer = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
    tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
    return similarity, embedding, summarizer, tokenizer


similarity_model, embedding_model, summarization_model, summarization_tokenizer = _load_models()

# ✅ Load the two CSV knowledge bases (disorder -> treatment, symptom questions).
# A missing file is reported in the page and the script run is halted.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as missing_file:
    st.error(f"❌ Missing dataset file: {missing_file}")
    st.stop()

# ✅ FAISS indices, built once per distinct dataset content instead of on every
# Streamlit rerun. The text tuples are hashed by st.cache_resource, so a change
# in the CSV contents rebuilds the corresponding index.
@st.cache_resource
def _build_disorder_index(disorders):
    """Embed disorder names and index them for cosine-similarity search.

    ``IndexFlatIP`` ranks by raw inner product, so the vectors are
    L2-normalised here to make that ranking equal to cosine similarity
    regardless of whether the checkpoint normalises its own output.

    Args:
        disorders: Tuple of disorder names (hashable, so it keys the cache).

    Returns:
        ``(embeddings, faiss_index)``.
    """
    embeddings = similarity_model.encode(list(disorders), convert_to_numpy=True, normalize_embeddings=True)
    disorder_index = faiss.IndexFlatIP(embeddings.shape[1])
    disorder_index.add(embeddings)
    return embeddings, disorder_index


@st.cache_resource
def _build_question_index(questions):
    """Embed question blocks and index them for L2 nearest-neighbour search.

    Args:
        questions: Tuple of question strings (hashable, so it keys the cache).

    Returns:
        ``(embeddings, faiss_index)``.
    """
    embeddings = embedding_model.encode(list(questions), convert_to_numpy=True)
    q_index = faiss.IndexFlatL2(embeddings.shape[1])
    q_index.add(embeddings)
    return embeddings, q_index


treatment_embeddings, index = _build_disorder_index(tuple(recommendations_df["Disorder"].tolist()))
question_embeddings, question_index = _build_question_index(tuple(questions_df["Questions"].tolist()))

# ✅ Function: Retrieve the most relevant question
def retrieve_questions(user_input):
    """Return the stored screening question closest to *user_input*.

    The single nearest row of ``questions_df["Questions"]`` (L2 distance in
    the ``embedding_model`` space) is selected. If that cell holds several
    questions separated by ", ", only the first one is returned.
    """
    query_vec = embedding_model.encode([user_input], convert_to_numpy=True)
    _, neighbour_ids = question_index.search(query_vec, 1)
    best = neighbour_ids[0][0]
    if best == -1:
        # FAISS reports "no neighbour" with id -1.
        return "I'm sorry, I couldn't find a relevant question."
    question_block = questions_df["Questions"].iloc[best]
    if ", " not in question_block:
        return question_block
    return question_block.split(", ", 1)[0]

# ✅ Function: Generate empathetic response using Groq API
def generate_empathetic_response(user_input, retrieved_question):
    """Ask the Groq-hosted LLM to rephrase *retrieved_question* empathetically.

    Args:
        user_input: The user's latest message, quoted into the prompt for context.
        retrieved_question: The screening question picked by ``retrieve_questions``.

    Returns:
        The model's reply text, or a fixed apology string if the API call
        fails or the reply has no content.
    """
    fallback = "I'm sorry, I couldn't process your request."
    prompt = f"""
    The user said: "{user_input}"
    Relevant Question: - {retrieved_question}
    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    """
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        content = response.choices[0].message.content
    except Exception:
        # UI boundary: keep the app responsive, but record why the call failed
        # instead of silently discarding the error.
        import logging

        logging.getLogger(__name__).exception("Groq chat completion failed")
        return fallback
    # Guard against a missing/empty reply so the chat never shows "AI: None".
    return content if content else fallback

# ✅ Function to detect disorders
def detect_disorders(chat_history, top_k=3):
    """Detect psychiatric disorders from full chat history.

    The messages are joined with spaces, embedded with ``similarity_model``
    and matched against the disorder-name index.

    Args:
        chat_history: Iterable of message strings.
        top_k: Maximum number of disorders to return (default 3).

    Returns:
        List of names from ``recommendations_df["Disorder"]``, best match
        first, or ``["No matching disorder found."]`` when nothing matches.
    """
    full_chat_text = " ".join(chat_history)
    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)

    # Never request more neighbours than the index holds.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return ["No matching disorder found."]
    _, indices = index.search(text_embedding, k)

    # FAISS pads missing neighbours with -1; iloc[-1] would silently return the
    # last row, so drop those ids instead of only checking the first slot.
    hits = [i for i in indices[0] if i >= 0]
    if not hits:
        return ["No matching disorder found."]
    return [recommendations_df["Disorder"].iloc[i] for i in hits]

# ✅ Function to get treatment recommendations
def get_treatment(detected_disorders):
    """Map each recognised disorder name to its treatment recommendation.

    Names absent from ``recommendations_df["Disorder"]`` are skipped. For a
    disorder listed on several rows, the first row's
    "Treatment Recommendation" is used.
    """
    known_disorders = recommendations_df["Disorder"].values
    treatments = {}
    for name in detected_disorders:
        if name not in known_disorders:
            continue
        matching_rows = recommendations_df[recommendations_df["Disorder"] == name]
        treatments[name] = matching_rows["Treatment Recommendation"].values[0]
    return treatments


# ✅ Streamlit UI Setup
st.title("🧠 MindSpark AI Psychiatric Assistant")

# Conversation transcript: a list of "User: ..." / "AI: ..." strings kept in
# session state so it survives Streamlit reruns.
chat_history = st.session_state.get("chat_history", [])
user_input = st.text_input("Enter your message:")

if st.button("Ask AI") and user_input:
    retrieved_question = retrieve_questions(user_input)
    empathetic_response = generate_empathetic_response(user_input, retrieved_question)
    chat_history.append(f"User: {user_input}")
    chat_history.append(f"AI: {empathetic_response}")
    st.session_state["chat_history"] = chat_history

st.subheader("Chat History")
for msg in chat_history:
    st.write(msg)

if st.button("Summarize Chat"):
    if chat_history:
        chat_text = " ".join(chat_history)
        # Truncate the prompt to at most 4096 tokens.
        inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
        summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
        summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        st.subheader("Chat Summary")
        st.write(summary)
    else:
        st.error("❌ Please enter chat history.")

# chat_history is a list of message strings: test it for emptiness directly and
# pass it to detect_disorders as-is (a list has no .strip()/.split()).
if st.button("Detect Disorders"):
    if chat_history:
        disorders = detect_disorders(chat_history)
        st.write("**Detected Disorders:**")
        for disorder in disorders:
            st.write(f"- {disorder}")
    else:
        st.error("❌ Please enter chat history.")

if st.button("Get Treatment Recommendations"):
    if chat_history:
        detected_disorders = detect_disorders(chat_history)
        treatments = get_treatment(detected_disorders)
        st.write("**Treatment Recommendations:**")
        for disorder, treatment in treatments.items():
            st.write(f"**{disorder}:** {treatment}")
    else:
        st.error("❌ Please enter chat history.")