mindspark121 committed on
Commit
ee3b51e
·
verified ·
1 Parent(s): 012a1d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -87
app.py CHANGED
@@ -1,132 +1,64 @@
1
  import os
2
  import streamlit as st
3
  import pandas as pd
4
- import faiss
 
 
 
 
 
 
 
 
5
  from sentence_transformers import SentenceTransformer
6
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
  from groq import Groq
8
 
9
- # βœ… Set up cache directory
10
  os.environ["HF_HOME"] = "/tmp/huggingface"
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
12
  os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"
13
 
14
- # βœ… Load API Key
15
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
16
  if not GROQ_API_KEY:
17
- st.error("❌ Error: GROQ_API_KEY is missing. Set it as an environment variable.")
18
  st.stop()
19
 
20
  client = Groq(api_key=GROQ_API_KEY)
21
 
22
- # βœ… Load AI Models
23
  st.sidebar.header("Loading AI Models... Please Wait ⏳")
24
  similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", cache_folder="/tmp/huggingface")
25
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder="/tmp/huggingface")
26
  summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
27
  summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base", cache_dir="/tmp/huggingface")
28
 
29
- # βœ… Load Datasets
30
  try:
31
  recommendations_df = pd.read_csv("treatment_recommendations.csv")
32
  questions_df = pd.read_csv("symptom_questions.csv")
33
  except FileNotFoundError as e:
34
- st.error(f"❌ Missing dataset file: {e}")
35
  st.stop()
36
 
37
- # βœ… FAISS Index for Disorders
38
  treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
39
  index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
40
  index.add(treatment_embeddings)
41
 
42
# Build an L2 (Euclidean) FAISS index over all screening questions so the
# closest stored question can be retrieved for any user message.
question_texts = questions_df["Questions"].tolist()
question_embeddings = embedding_model.encode(question_texts, convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)
47
def retrieve_questions(user_input):
    """Return the stored screening question most similar to *user_input*.

    Embeds the user text and does a 1-nearest-neighbour search against the
    module-level ``question_index``; falls back to an apology when FAISS
    reports no neighbour (index id -1).
    """
    query_vec = embedding_model.encode([user_input], convert_to_numpy=True)
    _, neighbour_ids = question_index.search(query_vec, 1)

    best_id = neighbour_ids[0][0]
    if best_id == -1:  # FAISS signals "no result" with -1
        return "I'm sorry, I couldn't find a relevant question."

    return questions_df["Questions"].iloc[best_id]
57
def generate_empathetic_response(user_input, retrieved_question):
    """Rephrase *retrieved_question* empathetically via the Groq chat API.

    Best-effort: any API failure yields a canned apology rather than an
    exception, so the chat UI never crashes on a transient network error.
    """
    prompt = f"""
    The user said: "{user_input}"
    Relevant Question:
    - {retrieved_question}

    You are an empathetic AI psychiatrist. Rephrase this question naturally.
    Example:
    - "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"

    Generate only one empathetic response.
    """
    conversation = [
        {"role": "system", "content": "You are an empathetic AI psychiatrist."},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = client.chat.completions.create(
            messages=conversation,
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        return completion.choices[0].message.content
    except Exception:
        # Deliberate broad catch: degrade gracefully instead of surfacing
        # API/parse errors to the end user.
        return "I'm sorry, I couldn't process your request."
82
def detect_disorders(chat_history):
    """Return up to three disorder names best matching the whole chat.

    Concatenates the transcript into one string, embeds it, and searches the
    module-level inner-product ``index`` built over disorder names.
    """
    transcript = " ".join(chat_history)
    transcript_vec = similarity_model.encode([transcript], convert_to_numpy=True)
    _, match_ids = index.search(transcript_vec, 3)

    if match_ids[0][0] == -1:  # FAISS signals "no result" with -1
        return ["No matching disorder found."]

    return [recommendations_df["Disorder"].iloc[i] for i in match_ids[0]]
93
def summarize_chat(chat_history):
    """Summarize the chat transcript with the long-T5 model.

    The transcript is truncated to the model's 4096-token window; generation
    uses beam search and caps the summary at 500 tokens.
    """
    transcript = " ".join(chat_history)
    encoded = summarization_tokenizer(
        "summarize: " + transcript,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
    )
    summary_ids = summarization_model.generate(
        encoded.input_ids, max_length=500, num_beams=4, early_stopping=True
    )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
100
# --- Streamlit chat UI ---
st.title("MindSpark AI Psychiatrist 💬")

# Per-session transcript, persisted across Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_input("You:", "")

if st.button("Send"):
    if user_input:
        retrieved_question = retrieve_questions(user_input)
        empathetic_response = generate_empathetic_response(user_input, retrieved_question)

        st.session_state.chat_history.append(f"User: {user_input}")
        st.session_state.chat_history.append(f"AI: {empathetic_response}")

# Show only the six most recent messages.
st.write("### Chat History")
for msg in st.session_state.chat_history[-6:]:
    st.text(msg)

if st.button("Summarize Chat"):
    summary = summarize_chat(st.session_state.chat_history)
    st.write("### Chat Summary")
    st.text(summary)

if st.button("Detect Disorders"):
    disorders = detect_disorders(st.session_state.chat_history)
    st.write("### Possible Disorders")
    st.text(", ".join(disorders))
 
1
  import os
2
  import streamlit as st
3
  import pandas as pd
4
import subprocess
import sys

# Ensure FAISS is available; install it on the fly as a last resort.
# NOTE(review): runtime pip installs are fragile and non-reproducible —
# prefer declaring faiss-cpu in requirements.txt.
try:
    import faiss
except ImportError:
    # Use the running interpreter's pip ("pip" alone may belong to a
    # different Python), and fail loudly if the install does not succeed
    # instead of falling through to a second, confusing ImportError.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "faiss-cpu"],
        check=True,
    )
    import faiss
13
  from sentence_transformers import SentenceTransformer
14
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
15
  from groq import Groq
16
 
17
# Point every Hugging Face cache at a writable tmp directory (the default
# home directory may be read-only in the hosting container).
# NOTE(review): these variables are set after `transformers` is imported;
# libraries that read them at import time may ignore the values — confirm.
for _cache_var in ("HF_HOME", "TRANSFORMERS_CACHE", "SENTENCE_TRANSFORMERS_HOME"):
    os.environ[_cache_var] = "/tmp/huggingface"

# The app cannot run without a Groq API key; stop early with a clear message.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)
29
 
30
# Load the embedding and summarization models once at startup.
st.sidebar.header("Loading AI Models... Please Wait ⏳")

_CACHE_DIR = "/tmp/huggingface"

# Sentence embeddings: one model for disorder matching, one for questions.
similarity_model = SentenceTransformer(
    "sentence-transformers/all-mpnet-base-v2", cache_folder=_CACHE_DIR
)
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=_CACHE_DIR)

# Long-context T5 for chat summarization.
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/long-t5-tglobal-base", cache_dir=_CACHE_DIR
)
summarization_tokenizer = AutoTokenizer.from_pretrained(
    "google/long-t5-tglobal-base", cache_dir=_CACHE_DIR
)
36
 
37
# Load the reference datasets; abort with a clear message if either is absent.
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"Missing dataset file: {e}")
    st.stop()

# Inner-product FAISS index over the disorder names.
# NOTE(review): IndexFlatIP on unnormalized embeddings ranks by raw dot
# product, not cosine similarity — confirm this is intended.
disorder_names = recommendations_df["Disorder"].tolist()
treatment_embeddings = similarity_model.encode(disorder_names, convert_to_numpy=True)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)
49
 
50
# --- Streamlit chat UI ---
st.title("MindSpark AI Psychiatrist 💬")

# Per-session transcript, persisted across Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_input("You:", "")
if st.button("Send"):
    if user_input:
        st.session_state.chat_history.append(f"User: {user_input}")
        # NOTE(review): the AI reply is a literal placeholder — the
        # retrieval/generation pipeline was removed in this revision.
        st.session_state.chat_history.append("AI: [Response]")

# Show only the six most recent messages.
st.write("### Chat History")
for msg in st.session_state.chat_history[-6:]:
    st.text(msg)