Divymakesml committed on
Commit
5294bbb
·
verified ·
1 Parent(s): 04d03d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -81
app.py CHANGED
@@ -1,15 +1,16 @@
1
  import streamlit as st
2
 
3
- # IMPORTANT: Call set_page_config before any other Streamlit command
4
  st.set_page_config(
5
  page_title="TARS: Therapist Assistance and Response System",
6
  page_icon="🧠"
7
  )
8
 
9
- from transformers import pipeline
10
  import torch
 
11
 
12
- # Canadian Crisis Resources
13
  CRISIS_RESOURCES = {
14
  "Canada Suicide Prevention Service": "1-833-456-4566",
15
  "Crisis Services Canada": "1-833-456-4566",
@@ -17,7 +18,7 @@ CRISIS_RESOURCES = {
17
  "First Nations and Inuit Hope for Wellness Help Line": "1-855-242-3310"
18
  }
19
 
20
- # Safety keywords for suicide risk detection
21
  SUICIDE_KEYWORDS = [
22
  "suicide", "kill myself", "end my life",
23
  "want to die", "hopeless", "no way out",
@@ -25,132 +26,157 @@ SUICIDE_KEYWORDS = [
25
  ]
26
 
27
class AITherapistAssistant:
    """Loads a small chat model and a summarizer, and exposes crisis
    detection, empathetic response generation, and session summarization.

    On any model-loading failure both pipelines are set to None and the
    public methods degrade to canned fallback strings.
    """

    def __init__(self):
        # Smaller models keep memory within Hugging Face Spaces constraints.
        device_id = 0 if torch.cuda.is_available() else -1
        try:
            # Conversational model
            self.conversation_model = pipeline(
                "text-generation",
                model="microsoft/DialoGPT-small",
                device=device_id,
            )
            # Summarization model
            self.summary_model = pipeline(
                "summarization",
                model="facebook/bart-large-cnn",
                device=device_id,
            )
        except Exception as e:
            # Surface the failure in the UI and disable both features.
            st.error(f"Model loading error: {e}")
            self.conversation_model = None
            self.summary_model = None

    def detect_crisis(self, message):
        """Return True when *message* contains any crisis keyword."""
        lowered = message.lower()
        return any(term in lowered for term in SUICIDE_KEYWORDS)

    def generate_response(self, message):
        """Generate empathetic AI response, or a fallback on any failure."""
        fallback = "I'm here to listen. Would you like to share more about how you're feeling?"
        if not self.conversation_model:
            return fallback
        try:
            # Contextual prompt to steer the model toward supportive output.
            guided_prompt = (
                "You are a compassionate AI therapist. Respond supportively to this message: "
                f"{message}. Be empathetic, validate feelings, and avoid giving direct medical advice."
            )
            generations = self.conversation_model(
                guided_prompt,
                max_length=100,
                num_return_sequences=1,
            )
            return generations[0]['generated_text']
        except Exception:
            return fallback

    def generate_summary(self, conversation):
        """Generate a professional therapy-style summary of *conversation*."""
        if not self.summary_model:
            return "Summary generation is temporarily unavailable."
        try:
            result = self.summary_model(
                conversation,
                max_length=130,
                min_length=30,
                do_sample=False,
            )
            return result[0]['summary_text']
        except Exception:
            return "Summary could not be generated."
97
 
98
def main():
    """Render the TARS chat UI: disclaimer, history, input, summary, sidebar."""
    st.title("🧠 TARS: Therapist Assistance and Response System")
    st.write("A supportive space to share your feelings safely.\n\n"
             "**Disclaimer**: I am not a licensed therapist. "
             "If you're in crisis, please reach out to professional help immediately.")

    # Lazily create session state on first run.
    state = st.session_state
    if 'conversation' not in state:
        state.conversation = []
    if 'assistant' not in state:
        state.assistant = AITherapistAssistant()

    # Replay the conversation so far.
    for entry in state.conversation:
        role = "user" if entry['sender'] == 'user' else "assistant"
        st.chat_message(role).write(entry['text'])

    # Handle a new user message.
    prompt = st.chat_input("Share your thoughts. I'm here to listen.")
    if prompt:
        # Surface crisis resources when risky language is detected.
        if state.assistant.detect_crisis(prompt):
            st.warning("⚠️ Crisis Support Detected")
            st.markdown("**Immediate Support Resources:**")
            for org, phone in CRISIS_RESOURCES.items():
                st.markdown(f"- {org}: `{phone}`")

        # Record and show the user's message.
        state.conversation.append({'sender': 'user', 'text': prompt})
        st.chat_message("user").write(prompt)

        # Produce and record the assistant's reply.
        with st.chat_message("assistant"):
            with st.spinner("Listening and reflecting..."):
                ai_response = state.assistant.generate_response(prompt)
                st.write(ai_response)
        state.conversation.append({'sender': 'ai', 'text': ai_response})

    # Offer a session summary once there is something to summarize.
    if state.conversation and st.button("Generate Session Summary"):
        transcript = " ".join(m['text'] for m in state.conversation)
        summary = state.assistant.generate_summary(transcript)
        st.markdown("**Session Summary:**")
        st.write(summary)

    # Always-visible safety information.
    st.sidebar.title("🆘 Crisis Support")
    st.sidebar.markdown("**Remember: You are valued and supported.**")
    for org, phone in CRISIS_RESOURCES.items():
        st.sidebar.markdown(f"- {org}: `{phone}`")


if __name__ == "__main__":
    main()
 
1
  import streamlit as st
2
 
3
+ # 1. Set page config FIRST
4
  st.set_page_config(
5
  page_title="TARS: Therapist Assistance and Response System",
6
  page_icon="🧠"
7
  )
8
 
9
+ import os
10
  import torch
11
+ from transformers import pipeline
12
 
13
+ # Canadian Crisis Resources (example)
14
  CRISIS_RESOURCES = {
15
  "Canada Suicide Prevention Service": "1-833-456-4566",
16
  "Crisis Services Canada": "1-833-456-4566",
 
18
  "First Nations and Inuit Hope for Wellness Help Line": "1-855-242-3310"
19
  }
20
 
21
+ # Keywords for detecting potential self-harm or suicidal language
22
  SUICIDE_KEYWORDS = [
23
  "suicide", "kill myself", "end my life",
24
  "want to die", "hopeless", "no way out",
 
26
  ]
27
 
28
  class AITherapistAssistant:
29
+ def __init__(self, conversation_model_name="microsoft/phi-1_5", summary_model_name="facebook/bart-large-cnn"):
30
+ """
31
+ Initialize the conversation (LLM) model and the summarization model.
32
+
33
+ If you truly have a different 'phi2' from Microsoft, replace 'microsoft/phi-1_5'
34
+ with your private or custom Hugging Face repo name.
35
+ """
36
+ # Load conversation LLM (phi2 / phi-1_5)
37
  try:
 
38
  self.conversation_model = pipeline(
39
+ "text-generation",
40
+ model=conversation_model_name,
41
  device=0 if torch.cuda.is_available() else -1
42
  )
43
+ except Exception as e:
44
+ st.error(f"Error loading conversation model: {e}")
45
+ self.conversation_model = None
46
+
47
+ # Load summarization model (BART Large CNN as default)
48
+ try:
49
  self.summary_model = pipeline(
50
  "summarization",
51
+ model=summary_model_name,
52
  device=0 if torch.cuda.is_available() else -1
53
  )
54
  except Exception as e:
55
+ st.error(f"Error loading summary model: {e}")
 
 
56
  self.summary_model = None
57
 
58
+ def detect_crisis(self, message: str) -> bool:
59
+ """Check if message contains suicidal or distress-related keywords."""
60
  message_lower = message.lower()
61
+ return any(keyword in message_lower for keyword in SUICIDE_KEYWORDS)
 
 
 
62
 
63
+ def generate_response(self, message: str) -> str:
64
+ """Generate a supportive AI response from the conversation model."""
65
  if not self.conversation_model:
66
+ return (
67
+ "I'm here to listen, but I'm currently having trouble loading my AI model. "
68
+ "Please try again later."
69
+ )
70
+
71
+ # Prompt to steer the model toward empathy and support
72
+ prompt = (
73
+ "You are a compassionate AI therapist. Respond supportively to this message:\n"
74
+ f"{message}\n\n"
75
+ "Be empathetic, validate feelings, and avoid giving direct medical advice."
76
+ )
77
 
78
  try:
79
+ outputs = self.conversation_model(
80
+ prompt,
81
+ max_length=250,
82
+ num_return_sequences=1,
83
+ do_sample=True,
84
+ top_p=0.9,
85
+ temperature=0.7
86
  )
87
+ response_text = outputs[0]["generated_text"]
 
 
 
 
 
88
 
89
+ # If the model echoes the prompt, strip it out:
90
+ if response_text.startswith(prompt):
91
+ response_text = response_text[len(prompt):].strip()
92
+
93
+ return response_text.strip()
94
+ except Exception as e:
95
+ st.error(f"Error generating response: {e}")
96
+ return "I'm sorry, but I'm having trouble responding right now."
97
 
98
+ def generate_summary(self, conversation_text: str) -> str:
99
+ """Generate a short summary of the entire conversation."""
100
  if not self.summary_model:
101
+ return "Summary model is unavailable at the moment."
102
 
103
  try:
104
+ summary_output = self.summary_model(
105
+ conversation_text,
106
  max_length=130,
107
  min_length=30,
108
  do_sample=False
109
+ )
110
+ return summary_output[0]["summary_text"]
111
+ except Exception as e:
112
+ st.error(f"Error generating summary: {e}")
113
+ return "Sorry, I couldn't generate a summary."
114
 
115
def main():
    """Streamlit entry point: chat UI, crisis alerts, summaries, sidebar."""
    st.title("🧠 TARS: Therapist Assistance and Response System")
    st.write(
        "A supportive space to share your feelings safely.\n\n"
        "**Disclaimer**: I am not a licensed therapist. If you're in crisis, "
        "please reach out to professional help immediately."
    )

    # Hugging Face Spaces sets SPACE_ID; surface that context to the user.
    if os.environ.get("SPACE_ID"):
        st.info("Running on Hugging Face Spaces.")

    state = st.session_state

    # Build the assistant once per session.
    if "assistant" not in state:
        state.assistant = AITherapistAssistant(
            conversation_model_name="microsoft/phi-1_5",  # replace if needed
            summary_model_name="facebook/bart-large-cnn",
        )

    # Conversation history persists in session state across reruns.
    if "conversation" not in state:
        state.conversation = []

    # Replay prior messages.
    for entry in state.conversation:
        role = "user" if entry["sender"] == "user" else "assistant"
        st.chat_message(role).write(entry["text"])

    # Handle a new user message.
    if prompt := st.chat_input("How are you feeling today?"):
        # Show crisis resources when risky language appears.
        if state.assistant.detect_crisis(prompt):
            st.warning("⚠️ Potential crisis detected.")
            st.markdown("**Immediate Support Resources (Canada):**")
            for org, phone in CRISIS_RESOURCES.items():
                st.markdown(f"- {org}: `{phone}`")

        # Record and show the user's message.
        state.conversation.append({"sender": "user", "text": prompt})
        st.chat_message("user").write(prompt)

        # Produce and record the assistant's reply.
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                ai_response = state.assistant.generate_response(prompt)
                st.write(ai_response)
        state.conversation.append({"sender": "assistant", "text": ai_response})

    # On-demand session summary.
    if st.button("Generate Session Summary"):
        if not state.conversation:
            st.info("No conversation to summarize yet.")
        else:
            transcript = " ".join(m["text"] for m in state.conversation)
            summary = state.assistant.generate_summary(transcript)
            st.subheader("Session Summary")
            st.write(summary)

    # Sidebar safety resources, always visible.
    st.sidebar.title("🆘 Crisis Support")
    st.sidebar.markdown("If you're in crisis, please contact:")
    for org, phone in CRISIS_RESOURCES.items():
        st.sidebar.markdown(f"- **{org}**: `{phone}`")


if __name__ == "__main__":
    main()