# Scraped Hugging Face file-viewer header (not part of the program):
# mindspark121's picture
# Create app.py
# a2a3f39 verified
# raw
# history blame
# 5.04 kB
import logging
import os

import streamlit as st
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq
# ✅ Set up cache directory
# Every Hugging Face cache variable is pointed at the same /tmp location.
# NOTE(review): these assignments run after the Hugging Face libraries have
# already been imported at the top of the file; any setting a library reads
# only at import time would not pick them up — confirm. The model loaders
# below also pass the cache path explicitly.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface",
        "SENTENCE_TRANSFORMERS_HOME": "/tmp/huggingface",
    }
)

# ✅ Load API Key
# The Groq client needs a key; when it is absent, show an error and stop the
# script instead of constructing a client without credentials.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ Error: GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)
# ✅ Load AI Models
st.sidebar.header("Loading AI Models... Please Wait ⏳")


@st.cache_resource
def _load_models():
    """Load the embedding and summarization models exactly once.

    Streamlit re-executes this script on every widget interaction; caching
    the loader as a resource keeps the loaded models in memory between reruns
    instead of re-instantiating all four objects each time.

    Returns:
        Tuple ``(similarity_model, embedding_model, summarization_model,
        summarization_tokenizer)``.
    """
    cache_dir = "/tmp/huggingface"
    similarity = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2", cache_folder=cache_dir
    )
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
    summarizer = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=cache_dir
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=cache_dir
    )
    return similarity, embedding, summarizer, tokenizer


# Same module-level names as before, so the rest of the file is unaffected.
(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_models()
# ✅ Load Datasets
# Any problem with the CSV files is reported through the Streamlit error box
# and stops the script, rather than surfacing later as a raw traceback.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()
except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
    # An empty or malformed CSV is as fatal as a missing one.
    st.error(f"❌ Could not read dataset file: {e}")
    st.stop()

# The rest of the file indexes these columns directly, so verify them up
# front instead of failing with a KeyError further down.
for _frame, _column, _path in (
    (recommendations_df, "Disorder", "treatment_recommendations.csv"),
    (questions_df, "Questions", "symptom_questions.csv"),
):
    if _column not in _frame.columns:
        st.error(f"❌ Dataset {_path} is missing the required '{_column}' column.")
        st.stop()
# ✅ FAISS Index for Disorders
@st.cache_resource
def _build_disorder_index(disorders):
    """Embed the disorder names and index them for inner-product search.

    Cached as a Streamlit resource and keyed on ``disorders``, so the
    embeddings are recomputed only when the list of names changes rather than
    on every script rerun.

    Args:
        disorders: List of disorder names to embed.

    Returns:
        Tuple ``(embeddings, faiss_index)``.
    """
    embeddings = similarity_model.encode(disorders, convert_to_numpy=True)
    disorder_index = faiss.IndexFlatIP(embeddings.shape[1])
    disorder_index.add(embeddings)
    return embeddings, disorder_index


treatment_embeddings, index = _build_disorder_index(
    recommendations_df["Disorder"].tolist()
)


# ✅ FAISS Index for Questions
@st.cache_resource
def _build_question_index(questions):
    """Embed the symptom questions and index them for L2 search.

    Cached as a Streamlit resource and keyed on ``questions``, so the
    embeddings are recomputed only when the list of questions changes.

    Args:
        questions: List of question strings to embed.

    Returns:
        Tuple ``(embeddings, faiss_index)``.
    """
    embeddings = embedding_model.encode(questions, convert_to_numpy=True)
    built_index = faiss.IndexFlatL2(embeddings.shape[1])
    built_index.add(embeddings)
    return embeddings, built_index


question_embeddings, question_index = _build_question_index(
    questions_df["Questions"].tolist()
)
# ✅ Retrieve Relevant Question
def retrieve_questions(user_input):
    """Return the stored question whose embedding is nearest to ``user_input``.

    Args:
        user_input: Free-text message from the user.

    Returns:
        The matching entry of the ``Questions`` column, or an apology string
        when the index search yields -1 (treated here as "no match").
    """
    query_vector = embedding_model.encode([user_input], convert_to_numpy=True)
    _distances, neighbours = question_index.search(query_vector, 1)
    best_match = neighbours[0][0]
    if best_match != -1:
        return questions_df["Questions"].iloc[best_match]
    return "I'm sorry, I couldn't find a relevant question."
# ✅ Generate Empathetic Question
def generate_empathetic_response(user_input, retrieved_question):
    """Ask the Groq chat model to rephrase a question empathetically.

    Args:
        user_input: The user's latest message, quoted in the prompt.
        retrieved_question: The dataset question to be rephrased.

    Returns:
        The model's reply text. If the request fails, or the reply is empty,
        a fixed apology string is returned so the chat can continue.
    """
    prompt = f"""
The user said: "{user_input}"
Relevant Question:
- {retrieved_question}
You are an empathetic AI psychiatrist. Rephrase this question naturally.
Example:
- "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"
Generate only one empathetic response.
"""
    fallback = "I'm sorry, I couldn't process your request."
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        reply = chat_completion.choices[0].message.content
    except Exception:
        # Boundary with an external service: keep the chat usable, but record
        # the failure (with traceback) instead of discarding it silently.
        logging.getLogger(__name__).exception("Groq chat completion failed")
        return fallback
    # A missing or empty completion would otherwise show up as "None"/blank
    # in the chat history.
    return reply or fallback
# ✅ Disorder Detection
def detect_disorders(chat_history):
    """Return up to three disorder names closest to the conversation text.

    Args:
        chat_history: List of chat message strings.

    Returns:
        A list of values from the ``Disorder`` column, or
        ``["No matching disorder found."]`` when the conversation is empty or
        the index search returns no valid hit.
    """
    full_chat_text = " ".join(chat_history)
    if not full_chat_text.strip():
        # Nothing has been said yet; an embedding of empty text would still
        # produce "nearest" disorders, which would be meaningless.
        return ["No matching disorder found."]

    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)
    # Unfilled result slots come back as -1; used with .iloc that would select
    # the last row, so every slot is checked rather than only the first.
    matches = [
        recommendations_df["Disorder"].iloc[i] for i in indices[0] if i != -1
    ]
    return matches or ["No matching disorder found."]
# ✅ Summarization
def summarize_chat(chat_history):
    """Summarize the conversation with the seq2seq summarization model.

    Args:
        chat_history: List of chat message strings.

    Returns:
        The decoded summary text, or a short notice when there is no
        conversation text to summarize.
    """
    chat_text = " ".join(chat_history)
    if not chat_text.strip():
        # Avoid running beam search over nothing but the "summarize: " prefix.
        return "No conversation to summarize yet."

    inputs = summarization_tokenizer(
        "summarize: " + chat_text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
    )
    # Pass the tokenizer's attention mask together with the ids instead of
    # leaving generate() to infer it.
    summary_ids = summarization_model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=500,
        num_beams=4,
        early_stopping=True,
    )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# ✅ UI - Streamlit Chatbot
# Page flow: title, text input with a Send button, the six most recent chat
# lines, then buttons for summarizing the chat and detecting disorders.
st.title("MindSpark AI Psychiatrist 💬")

# ✅ Chat History
# The conversation lives in session state so it persists between reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# ✅ User Input
user_input = st.text_input("You:", "")
if st.button("Send"):
    # Whitespace-only input carries nothing to retrieve or rephrase, so it is
    # ignored instead of being sent to the model and stored in the history.
    message = user_input.strip()
    if message:
        retrieved_question = retrieve_questions(message)
        empathetic_response = generate_empathetic_response(message, retrieved_question)
        st.session_state.chat_history.append(f"User: {message}")
        st.session_state.chat_history.append(f"AI: {empathetic_response}")

# ✅ Display Chat History
st.write("### Chat History")
for msg in st.session_state.chat_history[-6:]:  # Show last 6 messages
    st.text(msg)

# ✅ Summarization & Disorder Detection
if st.button("Summarize Chat"):
    summary = summarize_chat(st.session_state.chat_history)
    st.write("### Chat Summary")
    st.text(summary)

if st.button("Detect Disorders"):
    disorders = detect_disorders(st.session_state.chat_history)
    st.write("### Possible Disorders")
    st.text(", ".join(disorders))