mindspark121 committed on
Commit
e665e5f
·
verified ·
1 Parent(s): ee3b51e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -29
app.py CHANGED
@@ -1,64 +1,99 @@
import os
import streamlit as st
import pandas as pd
# faiss-cpu must be declared in requirements.txt; never `pip install` at
# runtime via subprocess — that hides the dependency and breaks read-only
# or offline deployments.
import faiss

from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq

# Keep all model caches under /tmp so the app runs on hosts where the
# default home directory is not writable.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

# The API key must come from the environment; fail fast with a visible error.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)

# Load AI models (downloaded on first run, then served from the cache dir).
st.sidebar.header("Loading AI Models... Please Wait ⏳")
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")

# Load the CSV datasets shipped with the app; stop with a clear message if
# either file is missing.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"Missing dataset file: {e}")
    st.stop()

# Inner-product FAISS index over disorder names.
treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# UI - Streamlit chatbot
st.title("MindSpark AI Psychiatrist 💬")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_input("You:", "")
if st.button("Send"):
    if user_input:
        st.session_state.chat_history.append(f"User: {user_input}")
        # Plain string — the original used an f-string with no placeholder.
        st.session_state.chat_history.append("AI: [Response]")

st.write("### Chat History")
# Show only the six most recent messages.
for msg in st.session_state.chat_history[-6:]:
    st.text(msg)
 
 
 
 
 
import os
import streamlit as st
import pandas as pd
import faiss

from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from groq import Groq

# ✅ Cache every model download under /tmp (writable on read-only hosts).
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

# ✅ The Groq key is read from the environment only; abort the app if unset.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)

# ✅ Embedding and summarization models, all served from the /tmp cache.
similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")

# ✅ CSV datasets shipped next to the app; stop with a visible error if absent.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()

# ✅ Inner-product FAISS index over disorder names (disorder detection).
treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)

# ✅ L2 FAISS index over the symptom-question bank (question retrieval).
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)
45
+
46
# ✅ Function: Retrieve the most relevant question
def retrieve_questions(user_input):
    """Return the single most relevant screening question for *user_input*.

    Embeds the user's message and runs a 1-nearest-neighbour search over
    the pre-built ``question_index``. A dataset row may hold several
    comma-separated questions; only the first one is returned.
    """
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)
    best = indices[0][0]
    if best == -1:  # FAISS reports -1 when no neighbour could be returned
        return "I'm sorry, I couldn't find a relevant question."
    # str.split returns the whole string when the separator is absent, so
    # the original's `", " in ...` membership test was redundant.
    return questions_df["Questions"].iloc[best].split(", ")[0]
54
+
55
# ✅ Function: Generate empathetic response using Groq API
def generate_empathetic_response(user_input, retrieved_question):
    """Rephrase *retrieved_question* empathetically via the Groq chat API.

    Returns the model's rewording of the retrieved question in the context
    of the user's message, or a fixed apology string if the API call fails
    for any reason (deliberate best-effort: the UI must not crash on a
    network or API error).
    """
    prompt = f"""
    The user said: "{user_input}"
    Relevant Question: - {retrieved_question}
    You are an empathetic AI psychiatrist. Rephrase this question naturally in a human-like way.
    """
    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful, empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        return response.choices[0].message.content
    # Broad on purpose (network, auth, rate-limit, parsing); the unused
    # `as e` binding from the original is dropped.
    except Exception:
        return "I'm sorry, I couldn't process your request."
75
+
76
# ✅ Streamlit UI Setup
st.title("🧠 MindSpark AI Psychiatric Assistant")

# Chat transcript persists across reruns in session_state.
chat_history = st.session_state.get("chat_history", [])
user_input = st.text_input("Enter your message:")

if st.button("Ask AI") and user_input:
    retrieved_question = retrieve_questions(user_input)
    empathetic_response = generate_empathetic_response(user_input, retrieved_question)
    chat_history.append(f"User: {user_input}")
    chat_history.append(f"AI: {empathetic_response}")
    st.session_state["chat_history"] = chat_history

st.subheader("Chat History")
for msg in chat_history:
    st.write(msg)

if st.button("Summarize Chat"):
    if not chat_history:
        # Robustness: the original ran the summarizer on an empty string
        # when no conversation existed yet.
        st.info("No conversation to summarize yet.")
    else:
        chat_text = " ".join(chat_history)
        # long-t5 accepts long inputs; truncate at the model's 4096-token cap.
        inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
        summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
        summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        st.subheader("Chat Summary")
        st.write(summary)