File size: 5,037 Bytes
a2a3f39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import streamlit as st
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq

# βœ… Point every Hugging Face cache variable at the same writable directory.
_CACHE_DIR = "/tmp/huggingface"
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "SENTENCE_TRANSFORMERS_HOME"):
    os.environ[_cache_var] = _CACHE_DIR

# βœ… Read the Groq API key from the environment; halt the app when it is absent.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ Error: GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)

# βœ… Load AI Models
st.sidebar.header("Loading AI Models... Please Wait ⏳")

@st.cache_resource(show_spinner=False)
def _load_models():
    """Load and cache the two sentence encoders and the LongT5 summarizer.

    Streamlit re-executes the whole script on every user interaction;
    without caching, all four models would be reloaded from disk (or
    re-downloaded) on each button click. ``st.cache_resource`` keeps one
    shared copy alive for the server process.

    Returns:
        tuple: (similarity_model, embedding_model,
                summarization_model, summarization_tokenizer)
    """
    sim = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
    emb = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
    summ = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
    tok = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
    return sim, emb, summ, tok

similarity_model, embedding_model, summarization_model, summarization_tokenizer = _load_models()

# βœ… Load Datasets and build FAISS indexes
@st.cache_resource(show_spinner=False)
def _load_data_and_indexes():
    """Read the CSV datasets and build both FAISS indexes exactly once.

    Cached with ``st.cache_resource`` because Streamlit reruns this script
    on every interaction; without caching the CSVs would be re-read and
    every row re-embedded on each button click.

    Returns:
        tuple: (recommendations_df, questions_df, disorder_index, question_index)
    """
    try:
        rec_df = pd.read_csv("treatment_recommendations.csv")
        q_df = pd.read_csv("symptom_questions.csv")
    except FileNotFoundError as e:
        st.error(f"❌ Missing dataset file: {e}")
        st.stop()

    # Inner-product index over disorder-name embeddings.
    # NOTE(review): the embeddings are not L2-normalized, so this is raw dot
    # product rather than cosine similarity — confirm that is intended.
    disorder_emb = similarity_model.encode(rec_df["Disorder"].tolist(), convert_to_numpy=True)
    disorder_index = faiss.IndexFlatIP(disorder_emb.shape[1])
    disorder_index.add(disorder_emb)

    # L2-distance index over the screening-question embeddings.
    q_emb = embedding_model.encode(q_df["Questions"].tolist(), convert_to_numpy=True)
    q_index = faiss.IndexFlatL2(q_emb.shape[1])
    q_index.add(q_emb)

    return rec_df, q_df, disorder_index, q_index

recommendations_df, questions_df, index, question_index = _load_data_and_indexes()

# βœ… Retrieve Relevant Question
def retrieve_questions(user_input):
    """Return the stored screening question nearest to the user's message.

    Embeds ``user_input`` and performs a k=1 nearest-neighbor lookup in
    the FAISS question index; falls back to an apology string when FAISS
    reports no match (index -1).
    """
    query_vec = embedding_model.encode([user_input], convert_to_numpy=True)
    _, hit_ids = question_index.search(query_vec, 1)
    best = hit_ids[0][0]
    if best == -1:
        return "I'm sorry, I couldn't find a relevant question."
    return questions_df["Questions"].iloc[best]

# βœ… Generate Empathetic Question
def generate_empathetic_response(user_input, retrieved_question):
    """Rephrase *retrieved_question* empathetically via the Groq chat API.

    Builds a prompt from the user's message plus the retrieved screening
    question and asks the LLM for a single empathetic rephrasing.

    Returns:
        str: the model's rephrased question, or a fixed apology string on
        any API failure. The broad except is a deliberate best-effort: a
        transient API/network error must not crash the chat UI.
    """
    prompt = f"""
    The user said: "{user_input}"
    Relevant Question:
    - {retrieved_question}
    
    You are an empathetic AI psychiatrist. Rephrase this question naturally.
    Example:
    - "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"

    Generate only one empathetic response.
    """
    try:
        chat_completion = client.chat.completions.create(
            messages=[{"role": "system", "content": "You are an empathetic AI psychiatrist."},
                      {"role": "user", "content": prompt}],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9
        )
        return chat_completion.choices[0].message.content
    except Exception:  # was `as e` but the binding was never used
        return "I'm sorry, I couldn't process your request."

# βœ… Disorder Detection
def detect_disorders(chat_history):
    """Return up to three disorder names whose embeddings best match the chat.

    The whole conversation is concatenated into one string, embedded, and
    searched (k=3) against the inner-product disorder index.
    """
    conversation = " ".join(chat_history)
    conv_vec = similarity_model.encode([conversation], convert_to_numpy=True)
    _, hits = index.search(conv_vec, 3)
    top_hits = hits[0]
    if top_hits[0] == -1:
        return ["No matching disorder found."]
    return [recommendations_df["Disorder"].iloc[hit] for hit in top_hits]

# βœ… Summarization
def summarize_chat(chat_history):
    """Summarize the conversation with LongT5 and return the summary text.

    Joins all chat messages, truncates the input to the 4096-token window,
    and beam-searches (4 beams) a summary of at most 500 tokens.

    Returns:
        str: the decoded summary, or "" when ``chat_history`` is empty —
        skips the expensive model call instead of summarizing nothing.
    """
    if not chat_history:
        return ""
    chat_text = " ".join(chat_history)
    inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
    # Pass the attention mask explicitly (all ones here — single unpadded
    # sequence) so generate() doesn't have to infer it and warn.
    summary_ids = summarization_model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=500,
        num_beams=4,
        early_stopping=True,
    )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# βœ… UI - Streamlit Chatbot
st.title("MindSpark AI Psychiatrist πŸ’¬")

# βœ… Chat History — create the session-scoped message list on first run only.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# βœ… User Input
user_input = st.text_input("You:", "")

# Ignore clicks with an empty input box (same behavior as the nested check).
if st.button("Send") and user_input:
    matched_question = retrieve_questions(user_input)
    ai_reply = generate_empathetic_response(user_input, matched_question)
    st.session_state.chat_history.extend(
        [f"User: {user_input}", f"AI: {ai_reply}"]
    )

# βœ… Display Chat History — only the six most recent messages.
st.write("### Chat History")
for message in st.session_state.chat_history[-6:]:
    st.text(message)

# βœ… Summarization & Disorder Detection
if st.button("Summarize Chat"):
    chat_summary = summarize_chat(st.session_state.chat_history)
    st.write("### Chat Summary")
    st.text(chat_summary)

if st.button("Detect Disorders"):
    detected = detect_disorders(st.session_state.chat_history)
    st.write("### Possible Disorders")
    st.text(", ".join(detected))