Update app.py
app.py CHANGED
@@ -1,99 +1,164 @@
Removed (old app.py; old lines 1-36 and parts of lines 86-99 are not legible in the captured diff):

try:
    import torch
except ImportError:
    raise ImportError("Missing dependency: torch. Install it with 'pip install torch'.")

# Optional: display a note if running on Hugging Face Spaces
if os.environ.get("SPACE_ID"):
    st.info("Running on Hugging Face Spaces.")

# Load the Hugging Face summarization model (cached to avoid reloading)
@st.cache(allow_output_mutation=True)
def load_summarizer():
    return pipeline("summarization", model="facebook/bart-large-cnn")

summarizer = load_summarizer()

# Initialize session state to store conversation history
if 'conversation' not in st.session_state:
    st.session_state.conversation = []

# App title and disclaimer
st.title("AI Friend Therapist Assistant")
st.write(
    "I'm here to listen. Share your thoughts below.\n\n"
    "*Disclaimer: I am not a licensed therapist. If you're in crisis, please seek professional help immediately.*"
)

# Text area for user input
user_input = st.text_area("Your message", height=150)

# Process user input when "Send" is clicked
if st.button("Send"):
    if user_input.strip():
        # Append the message to conversation history
        st.session_state.conversation.append(user_input)

        # Define keywords that might indicate suicidal ideation or distress
        flagged_keywords = [
            "suicide", "kill myself", "end my life", "self-harm",
            "self harm", "hopeless", "despair", "worthless", "can't go on"
        ]
        # Check if any flagged keyword is present in the input
        if any(keyword in user_input.lower() for keyword in flagged_keywords):
            st.error(
                "Your message indicates you may be experiencing distress. "
                "If you're in immediate danger, please call **911** immediately. "
                "For non-immediate crisis support in Canada, call **Crisis Services Canada at 1.833.456.4566** "
                "or visit [crisisservicescanada.ca](https://www.crisisservicescanada.ca/)."
            )
        else:
            st.
if st.
if st.session_state.conversation:
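A side note on the removed version: @st.cache(allow_output_mutation=True) is deprecated in current Streamlit releases, where unhashable objects such as models are cached with st.cache_resource instead. A minimal sketch of the old loader on the current API, reusing the same model id as the removed code:

import streamlit as st
from transformers import pipeline

@st.cache_resource  # replaces @st.cache(allow_output_mutation=True) for model objects
def load_summarizer():
    # Cached across reruns so the model is only downloaded and loaded once
    return pipeline("summarization", model="facebook/bart-large-cnn")

summarizer = load_summarizer()

The rewrite below sidesteps this by constructing the models inside AITherapistAssistant and keeping that instance in st.session_state.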
Added (new app.py):

import streamlit as st
from transformers import pipeline
import torch

# Canadian Crisis Resources
CRISIS_RESOURCES = {
    "Canada Suicide Prevention Service": "1-833-456-4566",
    "Crisis Services Canada": "1-833-456-4566",
    "Kids Help Phone": "1-800-668-6868",
    "First Nations and Inuit Hope for Wellness Help Line": "1-855-242-3310"
}

# Safety keywords for suicide risk detection
SUICIDE_KEYWORDS = [
    "suicide", "kill myself", "end my life",
    "want to die", "hopeless", "no way out",
    "better off dead", "pain is too much"
]

class AITherapistAssistant:
    def __init__(self):
        # Load smaller, faster models to work within Hugging Face Spaces constraints
        try:
            # Conversational model
            self.conversation_model = pipeline(
                "text-generation",
                model="microsoft/DialoGPT-small",
                device=0 if torch.cuda.is_available() else -1
            )

            # Summarization model
            # Note: the committed id "facebook/bart-small-summarization" does not resolve
            # to a published checkpoint; a small distilled BART summarizer is substituted here.
            self.summary_model = pipeline(
                "summarization",
                model="sshleifer/distilbart-cnn-12-6",
                device=0 if torch.cuda.is_available() else -1
            )
        except Exception as e:
            st.error(f"Model loading error: {e}")
            self.conversation_model = None
            self.summary_model = None
    def detect_crisis(self, message):
        """Detect potential suicide risk in a message."""
        message_lower = message.lower()
        for keyword in SUICIDE_KEYWORDS:
            if keyword in message_lower:
                return True
        return False

    def generate_response(self, message):
        """Generate an empathetic AI response."""
        if not self.conversation_model:
            return "I'm here to listen. Would you like to share more about how you're feeling?"

        try:
            # Contextual prompt to guide the response
            full_prompt = (
                f"You are a compassionate AI therapist. Respond supportively to this message: {message}. "
                "Be empathetic, validate feelings, and avoid giving direct medical advice."
            )

            # Generate a continuation; the text-generation pipeline returns the prompt
            # followed by the new text, so strip the prompt before showing the reply
            generated = self.conversation_model(
                full_prompt,
                max_length=100,
                num_return_sequences=1
            )[0]['generated_text']
            response = generated[len(full_prompt):].strip()

            return response or "I'm here to listen. Would you like to share more about how you're feeling?"

        except Exception as e:
            return "I'm here to listen. Would you like to share more about how you're feeling?"

    def generate_summary(self, conversation):
        """Generate a professional therapy-style summary."""
        if not self.summary_model:
            return "Summary generation is temporarily unavailable."

        try:
            # Generate summary
            summary = self.summary_model(
                conversation,
                max_length=130,
                min_length=30,
                do_sample=False
            )[0]['summary_text']

            return summary

        except Exception as e:
            return "Summary could not be generated."
def main():
    st.set_page_config(
        page_title="SafeSpace: AI Compassionate Listener",
        page_icon="🧠"
    )

    st.title("🧠 SafeSpace: Compassionate AI Listener")
    st.write("A supportive space to share your feelings safely")

    # Initialize session state
    if 'conversation' not in st.session_state:
        st.session_state.conversation = []

    if 'assistant' not in st.session_state:
        st.session_state.assistant = AITherapistAssistant()

    # Conversation Display
    for message in st.session_state.conversation:
        if message['sender'] == 'user':
            st.chat_message("user").write(message['text'])
        else:
            st.chat_message("assistant").write(message['text'])

    # User Input
    if prompt := st.chat_input("Share your thoughts. I'm here to listen."):
        # Check for crisis indicators
        if st.session_state.assistant.detect_crisis(prompt):
            st.warning("⚠️ Crisis Support Detected")
            st.markdown("**Immediate Support Resources:**")
            for org, phone in CRISIS_RESOURCES.items():
                st.markdown(f"- {org}: `{phone}`")

        # Add user message
        st.session_state.conversation.append({
            'sender': 'user',
            'text': prompt
        })

        # Display user message
        st.chat_message("user").write(prompt)

        # Generate AI response
        with st.chat_message("assistant"):
            with st.spinner("Listening and reflecting..."):
                ai_response = st.session_state.assistant.generate_response(prompt)
                st.write(ai_response)

        # Add AI message
        st.session_state.conversation.append({
            'sender': 'ai',
            'text': ai_response
        })

    # Session Summary Generation
    if st.session_state.conversation:
        if st.button("Generate Session Summary"):
            conversation_text = " ".join([msg['text'] for msg in st.session_state.conversation])
            summary = st.session_state.assistant.generate_summary(conversation_text)
            st.markdown("**Session Summary:**")
            st.write(summary)

    # Safety Information
    st.sidebar.title("🆘 Crisis Support")
    st.sidebar.markdown("**Remember: You are valued and supported.**")
    for org, phone in CRISIS_RESOURCES.items():
        st.sidebar.markdown(f"- {org}: `{phone}`")

if __name__ == "__main__":
    main()

# requirements.txt
# streamlit
# transformers
# torch
# accelerate
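The trailing comments list the dependencies, but a Hugging Face Space using the Streamlit SDK installs packages from a requirements.txt file at the repository root, not from comments inside app.py. A minimal requirements.txt matching the packages named above would be:

streamlit
transformers
torch
accelerate

Since the new code relies on st.chat_message and st.chat_input, a reasonably recent Streamlit release is required; pinning exact versions is optional but makes Space rebuilds reproducible.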