mindspark121 committed on
Commit
a2a3f39
·
verified ·
1 Parent(s): f50f657

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import faiss
5
+ from sentence_transformers import SentenceTransformer
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
+ from groq import Groq
8
+
9
# ✅ Set up cache directory
# NOTE(review): these variables are assigned after the Hugging Face libraries
# have already been imported at the top of the file, so they may not be
# honoured by those libraries -- confirm. The explicit cache arguments passed
# to every loader below are what pin the download location.
HF_CACHE_DIR = "/tmp/huggingface"
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_CACHE_DIR

# ✅ Load API Key
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("❌ Error: GROQ_API_KEY is missing. Set it as an environment variable.")
    st.stop()

client = Groq(api_key=GROQ_API_KEY)


# ✅ Load AI Models
@st.cache_resource
def _load_models():
    """Load the four models once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, each button click would reload every model.

    Returns:
        Tuple of (similarity model, embedding model, summarization model,
        summarization tokenizer).
    """
    similarity = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2", cache_folder=HF_CACHE_DIR
    )
    embedding = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_CACHE_DIR)
    summarizer = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=HF_CACHE_DIR
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base", cache_dir=HF_CACHE_DIR
    )
    return similarity, embedding, summarizer, tokenizer


@st.cache_resource
def _build_index(_encoder, texts, use_inner_product):
    """Embed ``texts`` with ``_encoder`` and store them in a flat FAISS index.

    The leading underscore on ``_encoder`` tells Streamlit not to hash it, so
    the cache key is ``(texts, use_inner_product)`` only. The two call sites
    below always differ in ``use_inner_product``, which keeps their entries
    apart.

    Args:
        _encoder: SentenceTransformer used to embed the texts.
        texts: Tuple of strings to index.
        use_inner_product: True for ``IndexFlatIP``, False for ``IndexFlatL2``.

    Returns:
        Tuple of (embeddings array, populated FAISS index).
    """
    embeddings = _encoder.encode(list(texts), convert_to_numpy=True)
    index_cls = faiss.IndexFlatIP if use_inner_product else faiss.IndexFlatL2
    built = index_cls(embeddings.shape[1])
    built.add(embeddings)
    return embeddings, built


st.sidebar.header("Loading AI Models... Please Wait ⏳")
(
    similarity_model,
    embedding_model,
    summarization_model,
    summarization_tokenizer,
) = _load_models()

# ✅ Load Datasets
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
except FileNotFoundError as e:
    st.error(f"❌ Missing dataset file: {e}")
    st.stop()

# ✅ FAISS Index for Disorders
treatment_embeddings, index = _build_index(
    similarity_model, tuple(recommendations_df["Disorder"].tolist()), True
)

# ✅ FAISS Index for Questions
question_embeddings, question_index = _build_index(
    embedding_model, tuple(questions_df["Questions"].tolist()), False
)
46
+
47
+ # ✅ Retrieve Relevant Question
48
def retrieve_questions(user_input):
    """Return the stored question closest to ``user_input``.

    The input is embedded with ``embedding_model`` and the single nearest
    neighbour is looked up in ``question_index``. When the index reports no
    neighbour (label ``-1``), an apology string is returned instead.

    Args:
        user_input: Text typed by the user.

    Returns:
        The matching entry of ``questions_df["Questions"]``, or a fallback
        message when nothing was found.
    """
    query_vector = embedding_model.encode([user_input], convert_to_numpy=True)
    _distances, neighbours = question_index.search(query_vector, 1)
    best_row = neighbours[0][0]

    if best_row != -1:
        return questions_df["Questions"].iloc[best_row]

    return "I'm sorry, I couldn't find a relevant question."
56
+
57
+ # ✅ Generate Empathetic Question
58
def generate_empathetic_response(user_input, retrieved_question):
    """Ask the Groq chat model to rephrase a question empathetically.

    Args:
        user_input: The user's latest message, quoted in the prompt.
        retrieved_question: The question to be rephrased.

    Returns:
        The model's reply, or a fixed apology string when the request fails
        or the reply is empty. This function never raises; failures are
        logged with their traceback instead of being silently discarded.
    """
    fallback = "I'm sorry, I couldn't process your request."
    # Prompt lines are kept flush-left so no indentation leaks into the text
    # sent to the model.
    prompt = f"""
The user said: "{user_input}"
Relevant Question:
- {retrieved_question}

You are an empathetic AI psychiatrist. Rephrase this question naturally.
Example:
- "I understand that anxiety can be overwhelming. Can you tell me more about when you started feeling this way?"

Generate only one empathetic response.
"""
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an empathetic AI psychiatrist."},
                {"role": "user", "content": prompt},
            ],
            model="llama-3.3-70b-versatile",
            temperature=0.8,
            top_p=0.9,
        )
        reply = chat_completion.choices[0].message.content
    except Exception:
        # Best-effort by design: the UI always gets a string back, but the
        # cause is recorded so failures can be diagnosed.
        import logging

        logging.getLogger(__name__).exception("Empathetic response generation failed")
        return fallback

    # Guard against an empty or missing completion.
    return reply if reply else fallback
81
+
82
+ # ✅ Disorder Detection
83
def detect_disorders(chat_history):
    """Return up to three disorders whose names best match the conversation.

    The chat messages are joined into one string, embedded with
    ``similarity_model`` and searched against ``index``.

    Args:
        chat_history: List of chat message strings.

    Returns:
        A list of entries from ``recommendations_df["Disorder"]``, or
        ``["No matching disorder found."]`` when the history is empty or the
        index yields no neighbour.
    """
    no_match = ["No matching disorder found."]

    full_chat_text = " ".join(chat_history)
    if not full_chat_text.strip():
        # Embedding an empty string would still return arbitrary neighbours.
        return no_match

    text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
    _, indices = index.search(text_embedding, 3)

    # FAISS pads the labels with -1 when the index holds fewer than k vectors;
    # skip the padding so .iloc[-1] cannot return the last row as a match.
    matches = [recommendations_df["Disorder"].iloc[i] for i in indices[0] if i != -1]
    return matches or no_match
92
+
93
+ # ✅ Summarization
94
def summarize_chat(chat_history):
    """Summarize the conversation with the seq2seq summarization model.

    Args:
        chat_history: List of chat message strings.

    Returns:
        The decoded summary text, or a short notice when there is nothing to
        summarize.
    """
    chat_text = " ".join(chat_history)
    if not chat_text.strip():
        # Avoid running generation on the bare "summarize: " prefix.
        return "No conversation to summarize yet."

    inputs = summarization_tokenizer(
        "summarize: " + chat_text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
    )
    summary_ids = summarization_model.generate(
        inputs.input_ids,
        # Pass the tokenizer's mask instead of letting generate() infer one.
        attention_mask=inputs.attention_mask,
        max_length=500,
        num_beams=4,
        early_stopping=True,
    )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
99
+
100
# ✅ UI - Streamlit Chatbot
st.title("MindSpark AI Psychiatrist 💬")

# ✅ Chat History
# The history lives in session state so it survives Streamlit's reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# ✅ User Input
user_input = st.text_input("You:", "")

if st.button("Send"):
    if user_input.strip():
        retrieved_question = retrieve_questions(user_input)
        empathetic_response = generate_empathetic_response(user_input, retrieved_question)

        st.session_state.chat_history.append(f"User: {user_input}")
        st.session_state.chat_history.append(f"AI: {empathetic_response}")
    else:
        # Do not send blank or whitespace-only messages to the model.
        st.warning("Please enter a message before sending.")

# ✅ Display Chat History
st.write("### Chat History")
for msg in st.session_state.chat_history[-6:]:  # Show last 6 messages
    st.text(msg)

# ✅ Summarization & Disorder Detection
if st.button("Summarize Chat"):
    if st.session_state.chat_history:
        summary = summarize_chat(st.session_state.chat_history)
        st.write("### Chat Summary")
        st.text(summary)
    else:
        st.info("There is no conversation to summarize yet.")

if st.button("Detect Disorders"):
    if st.session_state.chat_history:
        disorders = detect_disorders(st.session_state.chat_history)
        st.write("### Possible Disorders")
        st.text(", ".join(disorders))
    else:
        st.info("There is no conversation to analyze yet.")