import os
import streamlit as st
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq

# ✅ Set cache directories
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

# ✅ Securely fetch the API key
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)
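
# Note: the key must exist in the process environment before launching the app,
# e.g. from a shell (placeholder value): export GROQ_API_KEY="your-key"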

# ✅ Load AI models once per session (st.cache_resource prevents reloading on every Streamlit rerun)
@st.cache_resource
def load_models():
    similarity = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
    summarizer = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
    tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
    return similarity, embedding, summarizer, tokenizer

similarity_model, embedding_model, summarization_model, summarization_tokenizer = load_models()

# ✅ Load datasets
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()
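
# Expected CSV schema, inferred from the columns used below (adjust if your files differ):
#   treatment_recommendations.csv -> "Disorder", "Treatment Recommendation"
#   symptom_questions.csv         -> "Questions"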

# ✅ FAISS index for disorder detection (inner product over L2-normalised vectors equals cosine similarity)
treatment_embeddings = similarity_model.encode(
    recommendations_df["Disorder"].tolist(), convert_to_numpy=True, normalize_embeddings=True
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# ✅ FAISS index for question retrieval
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)
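
# Hypothetical usage sketch for the question index; search returns (distances, indices),
# and a smaller L2 distance means a closer match:
#   _, idx = question_index.search(embedding_model.encode(["I can't sleep"], convert_to_numpy=True), 1)
#   closest = questions_df["Questions"].iloc[idx[0][0]]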

# ✅ Function: Retrieve the most relevant question
def retrieve_questions(user_input):
    """Return the stored screening question closest to the user's message."""
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)
    if indices[0][0] == -1:
        return "I'm sorry, I couldn't find a relevant question."
    question_block = questions_df["Questions"].iloc[indices[0][0]]
    # Some rows store several comma-separated questions; return only the first
    return question_block.split(", ")[0] if ", " in question_block else question_block

# ✅ Function: Generate an empathetic response using the Groq API
def generate_empathetic_response(user_input, retrieved_question):
    prompt = f"""
    The user said: "{user_input}"
    Relevant Question: - {retrieved_question}
    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    """
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt}
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9
        )
        return response.choices[0].message.content
    except Exception:
        # Fail soft so an API error never crashes the UI
        return "I'm sorry, I couldn't process your request."

# ✅ Function to detect disorders
def detect_disorders(chat_history):
    """Detect likely psychiatric disorders from the full chat history."""
    if not chat_history:  # ✅ Handle an empty chat history
        return ["No input provided."]

    full_chat_text = " ".join(chat_history).strip()
    if not full_chat_text:  # ✅ Handle the case where all messages are empty strings
        return ["No meaningful text provided."]

    try:
        # Normalise the query embedding so the inner-product search matches the normalised index
        text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True, normalize_embeddings=True)
        _, indices = index.search(text_embedding, 3)

        if indices is None or indices[0][0] == -1:
            return ["No matching disorder found."]

        return [recommendations_df["Disorder"].iloc[i] for i in indices[0]]

    except Exception as e:
        return [f"Error detecting disorders: {e}"]  # ✅ Catch unexpected errors

# ✅ Function to get treatment recommendations
def get_treatment(detected_disorders):
    """Retrieve treatment recommendations based on detected disorders."""
    treatments = {
        disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
        for disorder in detected_disorders if disorder in recommendations_df["Disorder"].values
    }
    return treatments
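
# Hypothetical usage: get_treatment(["PTSD"]) -> {"PTSD": "<its Treatment Recommendation>"};
# disorders absent from the CSV are silently skipped.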


# ✅ Streamlit UI setup
st.title("🧠 MindSpark AI Psychiatric Assistant")

# Initialise chat history once so later st.session_state["chat_history"] lookups never raise KeyError
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
chat_history = st.session_state["chat_history"]
user_input = st.text_input("Enter your message:")

if st.button("Ask AI") and user_input:
    retrieved_question = retrieve_questions(user_input)
    empathetic_response = generate_empathetic_response(user_input, retrieved_question)
    chat_history.append(f"User: {user_input}")
    chat_history.append(f"AI: {empathetic_response}")
    st.session_state["chat_history"] = chat_history

st.subheader("Chat History")
for msg in chat_history:
    st.write(msg)

if st.button("Summarize Chat"):
    chat_text = " ".join(chat_history)
    inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
    summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    st.subheader("Chat Summary")
    st.write(summary)
if st.button("Detect Disorders"):
    if st.session_state["chat_history"]:
        disorders = detect_disorders(st.session_state["chat_history"])
        st.subheader("Detected Disorders:")
        for disorder in disorders:
            st.write(f"- {disorder}")
    else:
        st.error("❌ Please enter chat history.")

if st.button("Get Treatment Recommendations"):
    if st.session_state["chat_history"]:
        detected_disorders = detect_disorders(st.session_state["chat_history"])
        treatments = get_treatment(detected_disorders)
        st.subheader("Treatment Recommendations:")
        for disorder, treatment in treatments.items():
            st.write(f"**{disorder}:** {treatment}")
    else:
        st.error("❌ Please enter chat history.")