Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from transformers import pipeline
|
3 |
import torch
|
4 |
|
@@ -28,19 +35,20 @@ class AITherapistAssistant:
|
|
28 |
device=0 if torch.cuda.is_available() else -1
|
29 |
)
|
30 |
|
31 |
-
# Summarization model
|
32 |
self.summary_model = pipeline(
|
33 |
"summarization",
|
34 |
model="facebook/bart-large-cnn",
|
35 |
device=0 if torch.cuda.is_available() else -1
|
36 |
)
|
37 |
except Exception as e:
|
|
|
38 |
st.error(f"Model loading error: {e}")
|
39 |
self.conversation_model = None
|
40 |
self.summary_model = None
|
41 |
|
42 |
def detect_crisis(self, message):
|
43 |
-
"""Detect potential suicide risk in message"""
|
44 |
message_lower = message.lower()
|
45 |
for keyword in SUICIDE_KEYWORDS:
|
46 |
if keyword in message_lower:
|
@@ -48,14 +56,16 @@ class AITherapistAssistant:
|
|
48 |
return False
|
49 |
|
50 |
def generate_response(self, message):
|
51 |
-
"""Generate empathetic AI response"""
|
52 |
if not self.conversation_model:
|
53 |
return "I'm here to listen. Would you like to share more about how you're feeling?"
|
54 |
|
55 |
try:
|
56 |
# Contextual prompt to guide response
|
57 |
-
full_prompt =
|
58 |
-
|
|
|
|
|
59 |
# Generate response
|
60 |
response = self.conversation_model(
|
61 |
full_prompt,
|
@@ -65,37 +75,33 @@ class AITherapistAssistant:
|
|
65 |
|
66 |
return response
|
67 |
|
68 |
-
except Exception
|
69 |
return "I'm here to listen. Would you like to share more about how you're feeling?"
|
70 |
|
71 |
def generate_summary(self, conversation):
    """Generate a professional therapy-style summary of the session.

    Args:
        conversation: Full conversation text (all turns joined) to summarize.

    Returns:
        str: The model's summary, or a fixed fallback message when the
        summarization model is unavailable or summarization fails.
    """
    # Graceful degradation: model loading may have failed in __init__.
    if not self.summary_model:
        return "Summary generation is temporarily unavailable."

    try:
        # Generate summary via the BART summarization pipeline;
        # do_sample=False keeps output deterministic.
        summary = self.summary_model(
            conversation,
            max_length=130,
            min_length=30,
            do_sample=False
        )[0]['summary_text']
        return summary
    # BUG FIX: original read `except Exception` with no colon — a
    # SyntaxError that prevented the module from importing at all.
    except Exception:
        # Best-effort UI: never surface a raw traceback to the user.
        return "Summary could not be generated."
|
89 |
|
90 |
def main():
|
91 |
-
|
92 |
-
page_title="TARS: Therapist Assistance and Response System",
|
93 |
-
page_icon="🧠"
|
94 |
-
)
|
95 |
-
|
96 |
st.title("🧠 TARS: Therapist Assistance and Response System")
|
97 |
-
st.write("A supportive space to share your feelings safely"
|
98 |
-
|
|
|
|
|
99 |
# Initialize session state
|
100 |
if 'conversation' not in st.session_state:
|
101 |
st.session_state.conversation = []
|
@@ -103,14 +109,14 @@ def main():
|
|
103 |
if 'assistant' not in st.session_state:
|
104 |
st.session_state.assistant = AITherapistAssistant()
|
105 |
|
106 |
-
#
|
107 |
for message in st.session_state.conversation:
|
108 |
if message['sender'] == 'user':
|
109 |
st.chat_message("user").write(message['text'])
|
110 |
else:
|
111 |
st.chat_message("assistant").write(message['text'])
|
112 |
|
113 |
-
# User
|
114 |
if prompt := st.chat_input("Share your thoughts. I'm here to listen."):
|
115 |
# Check for crisis indicators
|
116 |
if st.session_state.assistant.detect_crisis(prompt):
|
@@ -119,13 +125,8 @@ def main():
|
|
119 |
for org, phone in CRISIS_RESOURCES.items():
|
120 |
st.markdown(f"- {org}: `{phone}`")
|
121 |
|
122 |
-
# Add user message
|
123 |
-
st.session_state.conversation.append({
|
124 |
-
'sender': 'user',
|
125 |
-
'text': prompt
|
126 |
-
})
|
127 |
-
|
128 |
-
# Display user message
|
129 |
st.chat_message("user").write(prompt)
|
130 |
|
131 |
# Generate AI response
|
@@ -134,16 +135,13 @@ def main():
|
|
134 |
ai_response = st.session_state.assistant.generate_response(prompt)
|
135 |
st.write(ai_response)
|
136 |
|
137 |
-
# Add AI
|
138 |
-
st.session_state.conversation.append({
|
139 |
-
'sender': 'ai',
|
140 |
-
'text': ai_response
|
141 |
-
})
|
142 |
|
143 |
# Session Summary Generation
|
144 |
if st.session_state.conversation:
|
145 |
if st.button("Generate Session Summary"):
|
146 |
-
conversation_text = " ".join(
|
147 |
summary = st.session_state.assistant.generate_summary(conversation_text)
|
148 |
st.markdown("**Session Summary:**")
|
149 |
st.write(summary)
|
@@ -156,9 +154,3 @@ def main():
|
|
156 |
|
157 |
if __name__ == "__main__":
|
158 |
main()
|
159 |
-
|
160 |
-
# requirements.txt
|
161 |
-
# streamlit
|
162 |
-
# transformers
|
163 |
-
# torch
|
164 |
-
# accelerate
|
|
|
1 |
import streamlit as st
|
2 |
+
|
3 |
+
# IMPORTANT: Call set_page_config before any other Streamlit command
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="TARS: Therapist Assistance and Response System",
|
6 |
+
page_icon="🧠"
|
7 |
+
)
|
8 |
+
|
9 |
from transformers import pipeline
|
10 |
import torch
|
11 |
|
|
|
35 |
device=0 if torch.cuda.is_available() else -1
|
36 |
)
|
37 |
|
38 |
+
# Summarization model
|
39 |
self.summary_model = pipeline(
|
40 |
"summarization",
|
41 |
model="facebook/bart-large-cnn",
|
42 |
device=0 if torch.cuda.is_available() else -1
|
43 |
)
|
44 |
except Exception as e:
|
45 |
+
# Display error if model loading fails
|
46 |
st.error(f"Model loading error: {e}")
|
47 |
self.conversation_model = None
|
48 |
self.summary_model = None
|
49 |
|
50 |
def detect_crisis(self, message):
|
51 |
+
"""Detect potential suicide risk in message."""
|
52 |
message_lower = message.lower()
|
53 |
for keyword in SUICIDE_KEYWORDS:
|
54 |
if keyword in message_lower:
|
|
|
56 |
return False
|
57 |
|
58 |
def generate_response(self, message):
|
59 |
+
"""Generate empathetic AI response."""
|
60 |
if not self.conversation_model:
|
61 |
return "I'm here to listen. Would you like to share more about how you're feeling?"
|
62 |
|
63 |
try:
|
64 |
# Contextual prompt to guide response
|
65 |
+
full_prompt = (
|
66 |
+
"You are a compassionate AI therapist. Respond supportively to this message: "
|
67 |
+
f"{message}. Be empathetic, validate feelings, and avoid giving direct medical advice."
|
68 |
+
)
|
69 |
# Generate response
|
70 |
response = self.conversation_model(
|
71 |
full_prompt,
|
|
|
75 |
|
76 |
return response
|
77 |
|
78 |
+
except Exception:
|
79 |
return "I'm here to listen. Would you like to share more about how you're feeling?"
|
80 |
|
81 |
def generate_summary(self, conversation):
    """Generate a professional therapy-style summary."""
    # No model means loading failed earlier; degrade gracefully.
    if not self.summary_model:
        return "Summary generation is temporarily unavailable."

    try:
        # Run the summarization pipeline; deterministic decoding.
        results = self.summary_model(
            conversation,
            max_length=130,
            min_length=30,
            do_sample=False
        )
        return results[0]['summary_text']
    except Exception:
        # Best-effort: hide model failures behind a friendly message.
        return "Summary could not be generated."
|
97 |
|
98 |
def main():
|
99 |
+
# Title and description
|
|
|
|
|
|
|
|
|
100 |
st.title("🧠 TARS: Therapist Assistance and Response System")
|
101 |
+
st.write("A supportive space to share your feelings safely.\n\n"
|
102 |
+
"**Disclaimer**: I am not a licensed therapist. "
|
103 |
+
"If you're in crisis, please reach out to professional help immediately.")
|
104 |
+
|
105 |
# Initialize session state
|
106 |
if 'conversation' not in st.session_state:
|
107 |
st.session_state.conversation = []
|
|
|
109 |
if 'assistant' not in st.session_state:
|
110 |
st.session_state.assistant = AITherapistAssistant()
|
111 |
|
112 |
+
# Display conversation
|
113 |
for message in st.session_state.conversation:
|
114 |
if message['sender'] == 'user':
|
115 |
st.chat_message("user").write(message['text'])
|
116 |
else:
|
117 |
st.chat_message("assistant").write(message['text'])
|
118 |
|
119 |
+
# User input with chat_input
|
120 |
if prompt := st.chat_input("Share your thoughts. I'm here to listen."):
|
121 |
# Check for crisis indicators
|
122 |
if st.session_state.assistant.detect_crisis(prompt):
|
|
|
125 |
for org, phone in CRISIS_RESOURCES.items():
|
126 |
st.markdown(f"- {org}: `{phone}`")
|
127 |
|
128 |
+
# Add user message to conversation
|
129 |
+
st.session_state.conversation.append({'sender': 'user', 'text': prompt})
|
|
|
|
|
|
|
|
|
|
|
130 |
st.chat_message("user").write(prompt)
|
131 |
|
132 |
# Generate AI response
|
|
|
135 |
ai_response = st.session_state.assistant.generate_response(prompt)
|
136 |
st.write(ai_response)
|
137 |
|
138 |
+
# Add AI response to conversation
|
139 |
+
st.session_state.conversation.append({'sender': 'ai', 'text': ai_response})
|
|
|
|
|
|
|
140 |
|
141 |
# Session Summary Generation
|
142 |
if st.session_state.conversation:
|
143 |
if st.button("Generate Session Summary"):
|
144 |
+
conversation_text = " ".join(msg['text'] for msg in st.session_state.conversation)
|
145 |
summary = st.session_state.assistant.generate_summary(conversation_text)
|
146 |
st.markdown("**Session Summary:**")
|
147 |
st.write(summary)
|
|
|
154 |
|
155 |
if __name__ == "__main__":
|
156 |
main()
|
|
|
|
|
|
|
|
|
|
|
|