Spaces:
Sleeping
Sleeping
File size: 6,251 Bytes
a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f d9adb77 990b877 d9adb77 990b877 d9adb77 990b877 d9adb77 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f a2a3f39 e665e5f d9adb77 70ae897 d9adb77 70ae897 d9adb77 70ae897 d9adb77 70ae897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import streamlit as st
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq
# Set cache directory: point every Hugging Face cache location at one
# scratch directory under /tmp.
# NOTE(review): these variables are assigned after the transformers /
# sentence_transformers imports at the top of the file; confirm the libraries
# still honour them when set this late (the loaders below also pass the
# directory explicitly).
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "SENTENCE_TRANSFORMERS_HOME"),
        "/tmp/huggingface",
    )
)
# Securely fetch the Groq API key from the environment.
# Surrounding whitespace (e.g. a trailing newline pasted into a secret) is
# stripped, and a blank value is treated the same as a missing one so the
# problem is reported here instead of at the first API call.
GROQ_API_KEY = (os.getenv("GROQ_API_KEY") or "").strip() or None
if not GROQ_API_KEY:
    st.error("❌ GROQ_API_KEY is missing. Set it as an environment variable.")
    # Halt this script run; nothing below can work without a client.
    st.stop()

# Shared Groq client used by generate_empathetic_response().
client = Groq(api_key=GROQ_API_KEY)
# Load AI models.
# Streamlit re-executes this script from the top on every widget interaction,
# so the loaders are wrapped in st.cache_resource: each model is built once
# and the same object is handed back on later reruns instead of being
# re-instantiated on every button click.
_MODEL_CACHE_DIR = "/tmp/huggingface"


@st.cache_resource
def _load_sentence_encoder(model_name):
    """Return the SentenceTransformer called ``model_name``, cached on disk under _MODEL_CACHE_DIR."""
    return SentenceTransformer(model_name, cache_folder=_MODEL_CACHE_DIR)


@st.cache_resource
def _load_summarizer(model_name):
    """Return ``(model, tokenizer)`` for the seq2seq checkpoint ``model_name``."""
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=_MODEL_CACHE_DIR)
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=_MODEL_CACHE_DIR)
    return model, tokenizer


# Encoder used for the disorder index and for embedding the chat text.
similarity_model = _load_sentence_encoder("sentence-transformers/all-mpnet-base-v2")
# Encoder used for the question index and for embedding the user's message.
embedding_model = _load_sentence_encoder("all-MiniLM-L6-v2")
# Model/tokenizer pair used by the "Summarize Chat" action.
summarization_model, summarization_tokenizer = _load_summarizer("google/long-t5-tglobal-base")
# Load datasets.
# Both CSV files are read from the current working directory.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()

# The rest of the script indexes these columns directly; report a missing one
# here with a readable message instead of a KeyError traceback further down.
_missing_columns = [
    f"{_file_name}: {_column}"
    for _file_name, _frame, _required in (
        ("treatment_recommendations.csv", recommendations_df, ("Disorder", "Treatment Recommendation")),
        ("symptom_questions.csv", questions_df, ("Questions",)),
    )
    for _column in _required
    if _column not in _frame.columns
]
if _missing_columns:
    st.error(f"❌ Missing dataset column(s): {', '.join(_missing_columns)}")
    st.stop()
# FAISS index for disorder detection.
# Every value of the "Disorder" column is embedded with similarity_model and
# stored in a flat inner-product index; row i of the index corresponds to
# row i of recommendations_df.
_disorder_names = recommendations_df["Disorder"].tolist()
treatment_embeddings = similarity_model.encode(
    _disorder_names,
    convert_to_numpy=True,
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[-1])
index.add(treatment_embeddings)
# FAISS index for question retrieval.
# Every value of the "Questions" column is embedded with embedding_model and
# stored in a flat L2 index; row i of the index corresponds to row i of
# questions_df.
_question_texts = questions_df["Questions"].tolist()
question_embeddings = embedding_model.encode(
    _question_texts,
    convert_to_numpy=True,
)
question_index = faiss.IndexFlatL2(question_embeddings.shape[-1])
question_index.add(question_embeddings)
# Function: retrieve the most relevant question
def retrieve_questions(user_input):
    """Return the stored question closest to ``user_input``.

    The message is embedded with ``embedding_model`` and the single nearest
    entry of ``question_index`` is looked up in ``questions_df["Questions"]``.
    When that entry contains ``", "``, only the text before the first
    ``", "`` is returned.

    Args:
        user_input: The user's message.

    Returns:
        The selected question text, or an apology string when the index
        reports no neighbour (-1).
    """
    query_vector = embedding_model.encode([user_input], convert_to_numpy=True)
    _, neighbours = question_index.search(query_vector, 1)
    nearest = neighbours[0][0]

    # FAISS uses -1 to signal "no result" for a slot.
    if nearest == -1:
        return "I'm sorry, I couldn't find a relevant question."

    question_block = questions_df["Questions"].iloc[nearest]
    if ", " not in question_block:
        return question_block
    return question_block.split(", ", 1)[0]
# Function: generate empathetic response using the Groq API
def generate_empathetic_response(user_input, retrieved_question):
    """Ask the Groq chat model to rephrase a retrieved question empathetically.

    Args:
        user_input: The user's message, quoted verbatim in the prompt.
        retrieved_question: The question to be rephrased.

    Returns:
        The content of the first completion choice, or a fixed apology string
        if anything goes wrong while calling the API. Failures are logged with
        their traceback rather than silently discarded.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    # Same text the original triple-quoted f-string produced, one line each.
    prompt = (
        "\n"
        f'The user said: "{user_input}"\n'
        f"Relevant Question: - {retrieved_question}\n"
        "You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.\n"
    )
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        return response.choices[0].message.content
    except Exception:
        # Deliberate best-effort boundary: the UI always gets a string back,
        # but the cause is recorded so outages and bad keys can be diagnosed.
        logging.getLogger(__name__).exception("Groq chat completion failed")
        return "I'm sorry, I couldn't process your request."
# Function to detect disorders
def detect_disorders(chat_history, top_k=3):
    """Detect psychiatric disorders from the full chat history.

    All messages are joined with spaces, embedded with ``similarity_model``
    and matched against the disorder ``index``.

    Args:
        chat_history: Sequence of message strings.
        top_k: Maximum number of disorders to return (default 3).

    Returns:
        A list of strings: the matched values of
        ``recommendations_df["Disorder"]``, or a single-element list holding
        a status/error message when there is no usable input, no match, or
        an unexpected error occurs.
    """
    # Handle empty chat history.
    if not chat_history:
        return ["No input provided."]

    full_chat_text = " ".join(chat_history).strip()
    # Handle the case where every message is empty or whitespace.
    if not full_chat_text:
        return ["No meaningful text provided."]

    try:
        text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
        _, indices = index.search(text_embedding, top_k)

        # FAISS pads the result with -1 when the index holds fewer than
        # top_k vectors. Those slots must be dropped: passing -1 to .iloc
        # would silently return the LAST row as if it were a match.
        matched_rows = [int(i) for i in indices[0] if i >= 0]
        if not matched_rows:
            return ["No matching disorder found."]

        return [recommendations_df["Disorder"].iloc[i] for i in matched_rows]
    except Exception as e:
        # Catch unexpected errors and surface them as a message, matching the
        # function's "always return a list of strings" contract.
        return [f"Error detecting disorders: {str(e)}"]
# Function to get treatment recommendations
def get_treatment(detected_disorders):
    """Look up the treatment recommendation for each detected disorder.

    Entries that do not occur in ``recommendations_df["Disorder"]`` are
    skipped. When a disorder occurs in more than one row, the value from the
    first matching row is used.

    Args:
        detected_disorders: Iterable of disorder names.

    Returns:
        Dict mapping each recognised disorder to its
        ``"Treatment Recommendation"`` value.
    """
    treatments = {}
    for disorder in detected_disorders:
        # Skip anything that is not a known disorder name.
        if disorder not in recommendations_df["Disorder"].values:
            continue
        matching_rows = recommendations_df[recommendations_df["Disorder"] == disorder]
        treatments[disorder] = matching_rows["Treatment Recommendation"].values[0]
    return treatments
# Streamlit UI setup
st.title("🧠 MindSpark AI Psychiatric Assistant")

# Make sure the history key exists before any handler reads it, so that
# pressing a button before the first message cannot raise KeyError.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
chat_history = st.session_state["chat_history"]

user_input = st.text_input("Enter your message:")

# Ask AI: retrieve the closest stored question, have the LLM rephrase it,
# and append both sides of the exchange to the history.
if st.button("Ask AI") and user_input:
    retrieved_question = retrieve_questions(user_input)
    empathetic_response = generate_empathetic_response(user_input, retrieved_question)
    chat_history.append(f"User: {user_input}")
    chat_history.append(f"AI: {empathetic_response}")
    st.session_state["chat_history"] = chat_history

st.subheader("Chat History")
for msg in chat_history:
    st.write(msg)

# Summarize Chat: run the seq2seq summarizer over the joined history.
if st.button("Summarize Chat"):
    if chat_history:
        chat_text = " ".join(chat_history)
        # Input is truncated to 4096 tokens; the summary is capped at 500.
        inputs = summarization_tokenizer(
            "summarize: " + chat_text,
            return_tensors="pt",
            max_length=4096,
            truncation=True,
        )
        summary_ids = summarization_model.generate(
            inputs.input_ids,
            max_length=500,
            num_beams=4,
            early_stopping=True,
        )
        summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        st.subheader("Chat Summary")
        st.write(summary)
    else:
        # Nothing to summarize yet; same message the other actions use.
        st.error("❌ Please enter chat history.")

# Detect Disorders: list the disorders matched against the whole history.
if st.button("Detect Disorders"):
    if chat_history:
        disorders = detect_disorders(chat_history)
        st.subheader("Detected Disorders:")
        for disorder in disorders:
            st.write(f"- {disorder}")
    else:
        st.error("❌ Please enter chat history.")

# Get Treatment Recommendations: detect disorders, then show the stored
# recommendation for each one that exists in the dataset.
if st.button("Get Treatment Recommendations"):
    if chat_history:
        detected_disorders = detect_disorders(chat_history)
        treatments = get_treatment(detected_disorders)
        st.subheader("Treatment Recommendations:")
        for disorder, treatment in treatments.items():
            st.write(f"**{disorder}:** {treatment}")
    else:
        st.error("❌ Please enter chat history.")
|