File size: 5,037 Bytes
a2a3f39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import streamlit as st
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq
# ✅ Set up cache directory
# Route the Hugging Face / sentence-transformers caches to one directory under /tmp.
# NOTE(review): these are assigned after the Hugging Face imports at the top of
# the file; libraries that read them at import time may not see them — confirm.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface",
        "SENTENCE_TRANSFORMERS_HOME": "/tmp/huggingface",
    }
)
# ✅ Load API Key
# Read the Groq API key from the environment. Without it no chat completion
# can be requested, so report the problem in the UI and halt this script run.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ Error: GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

# Client used by generate_empathetic_response() for chat completions.
client = Groq(api_key=GROQ_API_KEY)
# ✅ Load AI Models
st.sidebar.header("Loading AI Models... Please Wait ⏳")


# Streamlit re-executes this script on every widget interaction. Wrapping the
# loaders in st.cache_resource keeps one set of model objects per server
# process instead of re-instantiating all four on each rerun.
@st.cache_resource
def _load_models():
    """Instantiate and return (similarity_model, embedding_model, summarizer, tokenizer)."""
    cache_dir = "/tmp/huggingface"
    return (
        SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder=cache_dir),
        SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir),
        AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir=cache_dir),
        AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir=cache_dir),
    )


# similarity_model embeds disorders / chat text; embedding_model embeds the
# symptom questions; the seq2seq model and its tokenizer back summarize_chat().
similarity_model, embedding_model, summarization_model, summarization_tokenizer = _load_models()
# ✅ Load Datasets
# Load the two CSV datasets from the working directory. If either file is
# missing, show the error in the UI and halt this script run, because every
# later step depends on both frames.
try:
    # Read below via its "Disorder" column.
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    # Read below via its "Questions" column.
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()
# ✅ FAISS Index for Disorders
# Embedding every disorder name is expensive and Streamlit reruns this script
# on each interaction, so the embeddings and index are built once and cached.
# NOTE: the cache outlives edits to the CSV; restart the app (or clear the
# Streamlit cache) after changing the dataset.
@st.cache_resource
def _build_disorder_index():
    """Embed the "Disorder" column and return (embeddings, inner-product FAISS index)."""
    embeddings = similarity_model.encode(
        recommendations_df["Disorder"].tolist(), convert_to_numpy=True
    )
    disorder_index = faiss.IndexFlatIP(embeddings.shape[1])
    disorder_index.add(embeddings)
    return embeddings, disorder_index


treatment_embeddings, index = _build_disorder_index()
# ✅ FAISS Index for Questions
# Embedding every symptom question is expensive and Streamlit reruns this
# script on each interaction, so the embeddings and index are built once and
# cached.
# NOTE: the cache outlives edits to the CSV; restart the app (or clear the
# Streamlit cache) after changing the dataset.
@st.cache_resource
def _build_question_index():
    """Embed the "Questions" column and return (embeddings, L2 FAISS index)."""
    embeddings = embedding_model.encode(
        questions_df["Questions"].tolist(), convert_to_numpy=True
    )
    l2_index = faiss.IndexFlatL2(embeddings.shape[1])
    l2_index.add(embeddings)
    return embeddings, l2_index


question_embeddings, question_index = _build_question_index()
# ✅ Retrieve Relevant Question
def retrieve_questions(user_input: str) -> str:
    """Return the stored symptom question nearest to ``user_input``.

    The input is embedded with ``embedding_model`` and the single nearest
    neighbour is looked up in ``question_index``.

    Args:
        user_input: The user's latest message.

    Returns:
        The matching entry of ``questions_df["Questions"]``, or a fixed apology
        string when the input is blank or the index yields no neighbour.
    """
    fallback = "I'm sorry, I couldn't find a relevant question."

    # A blank message has nothing meaningful to embed; matching it would
    # surface an arbitrary question.
    if not user_input or not user_input.strip():
        return fallback

    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)

    # FAISS reports -1 as the label when it has no neighbour to return.
    best = int(indices[0][0])
    if best == -1:
        return fallback
    return questions_df["Questions"].iloc[best]
# ✅ Generate Empathetic Question
def generate_empathetic_response(user_input: str, retrieved_question: str) -> str:
    """Ask the Groq chat model to rephrase a retrieved question empathetically.

    Args:
        user_input: The user's latest message, quoted in the prompt.
        retrieved_question: The question to be rephrased.

    Returns:
        The model's reply, or a fixed apology string when the request fails or
        the reply is empty.
    """
    import logging  # local import: the module's import block is left untouched

    fallback = "I'm sorry, I couldn't process your request."
    prompt = (
        "\n"
        f'The user said: "{user_input}"\n'
        "Relevant Question:\n"
        f"- {retrieved_question}\n"
        "You are an empathetic AI psychiatrist. Rephrase this question naturally.\n"
        "Example:\n"
        '- "I understand that anxiety can be overwhelming. '
        'Can you tell me more about when you started feeling this way?"\n'
        "Generate only one empathetic response.\n"
    )
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        reply = chat_completion.choices[0].message.content
    except Exception:
        # UI boundary: keep the chat usable, but record the failure instead of
        # discarding it silently.
        logging.getLogger(__name__).exception("Groq chat completion failed")
        return fallback
    # An absent or empty reply would otherwise be shown to the user as "AI: None".
    return reply or fallback
# ✅ Disorder Detection
def detect_disorders(chat_history):
    """Return up to three disorder names nearest to the whole conversation.

    The chat messages are joined into one string, embedded with
    ``similarity_model`` and searched against the disorder ``index``.

    Args:
        chat_history: Sequence of chat message strings.

    Returns:
        A list of entries from ``recommendations_df["Disorder"]``, or
        ``["No matching disorder found."]`` when the history is blank or the
        search yields no neighbour.
    """
    no_match = ["No matching disorder found."]

    full_chat_text = " ".join(chat_history)
    # With no conversation there is nothing to match; embedding an empty
    # string would still return three arbitrary disorders.
    if not full_chat_text.strip():
        return no_match

    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)

    # FAISS pads the result with -1 when fewer than k vectors are indexed.
    # Drop those labels: iloc[-1] would silently return the last row.
    hits = [int(i) for i in indices[0] if i != -1]
    if not hits:
        return no_match

    disorder_names = recommendations_df["Disorder"]
    return [disorder_names.iloc[i] for i in hits]
# ✅ Summarization
def summarize_chat(chat_history) -> str:
    """Summarize the conversation with the seq2seq summarization model.

    Args:
        chat_history: Sequence of chat message strings.

    Returns:
        The decoded summary, or a short notice when the history is blank.
    """
    chat_text = " ".join(chat_history)
    # With no conversation the model would be run on the bare "summarize: "
    # prefix and produce meaningless text.
    if not chat_text.strip():
        return "No conversation to summarize yet."

    # Input is truncated to 4096 tokens; the summary is capped at 500 tokens.
    inputs = summarization_tokenizer(
        "summarize: " + chat_text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
    )
    summary_ids = summarization_model.generate(
        inputs.input_ids, max_length=500, num_beams=4, early_stopping=True
    )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# ✅ UI - Streamlit Chatbot
st.title("MindSpark AI Psychiatrist 💬")

# ✅ Chat History
# The history is kept in st.session_state and starts as an empty list.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# ✅ User Input
user_input = st.text_input("You:", "")
if st.button("Send"):
    if user_input:
        # Find the closest stored question, have the LLM rephrase it, then
        # record both sides of the exchange.
        retrieved_question = retrieve_questions(user_input)
        empathetic_response = generate_empathetic_response(user_input, retrieved_question)
        st.session_state.chat_history.append(f"User: {user_input}")
        st.session_state.chat_history.append(f"AI: {empathetic_response}")

# ✅ Display Chat History
st.write("### Chat History")
for msg in st.session_state.chat_history[-6:]:  # Show last 6 messages
    st.text(msg)

# ✅ Summarization & Disorder Detection
if st.button("Summarize Chat"):
    summary = summarize_chat(st.session_state.chat_history)
    st.write("### Chat Summary")
    st.text(summary)

if st.button("Detect Disorders"):
    disorders = detect_disorders(st.session_state.chat_history)
    st.write("### Possible Disorders")
    st.text(", ".join(disorders))
|