Update app.py

app.py CHANGED
@@ -1,112 +1,93 @@
-import streamlit as st
 import os
-
-
-import
-from transformers import AutoTokenizer,
-from tensorflow.keras.models import load_model
-from tensorflow.keras.preprocessing.text import Tokenizer
-from tensorflow.keras.preprocessing.sequence import pad_sequences
-
-# Load the tokenizer used in training
-tokenizer = Tokenizer(num_words=10000)
-# You must re-train or load the tokenizer from a JSON file if you saved it!
-tokenizer.fit_on_texts(["dummy"])  # Temporary; replace with the loaded tokenizer
-
-# Preprocess text for the models
-def preprocess(text):
-    sequence = tokenizer.texts_to_sequences([text])
-    return pad_sequences(sequence, maxlen=100)
-
-# Load Keras models
-model1 = load_model("model1.h5")  # Suicide risk
-model2 = load_model("best_model (2).keras")  # Diagnosis classifier
-
-# Model prediction wrappers
-def model1_predict(text):
-    pred = model1.predict(preprocess(text))[0][0]
-    return int(pred > 0.5)
-
-def model2_predict(text):
-    pred = model2.predict(preprocess(text))[0]
-    return int(np.argmax(pred))

-
-
-    2: "Depression",
-    3: "Bipolar disorder",
-    4: "PTSD",
-    5: "OCD",
-    6: "ADHD",
-    7: "General emotional distress"
-}

 @st.cache_resource
-def
-    model_id = "
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model =
-
-        device_map="auto",
-        trust_remote_code=True,
-        torch_dtype="auto"
-    )
-    return pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")

-generator =

-# Session memory
 if "history" not in st.session_state:
     st.session_state.history = []

-
-
-    risk = model1_predict(user_input)
-
-    if risk == 1:
-        response = (
-            "I'm really sorry you're feeling this way. You're not alone – please talk to someone you trust "
-            "or a professional. I'm here to listen, but it's important to get real support too. Please contact 9-8-8 if you need immediate support. I hope you get better."
-        )
-    else:
-        diagnosis_code = model2_predict(user_input)
-        diagnosis = diagnosis_labels.get(diagnosis_code, "General emotional distress")
-
-        prompt = f"""You are an empathetic AI therapist. The user has been diagnosed with {diagnosis}. Respond supportively.
-
-User: {user_input}
-AI:"""
-
-        response = generator(prompt, max_new_tokens=150, temperature=0.7)[0]["generated_text"]
-        response = response.split("AI:")[-1].strip()
-
-    st.session_state.history.append(f"AI: {response}")
-    return response
-
-def summarize_session():
-    session_text = "\n".join(st.session_state.history)
-    prompt = f"""Summarize the emotional state of the user based on the following conversation. Include emotional cues and possible diagnoses. Write it like a therapist note.

-
-

-
-
-    return summary.split("Summary:")[-1].strip()

-#
-
-

 if user_input:
-
-    st.
-
-    if
-
-
-
-
-
-
-
 import os
+import time
+from datetime import datetime
+import streamlit as st
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

+# -- SETUP --
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
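+# Note (assumption about intent): forcing the pure-Python protobuf parser is a
+# common workaround for protobuf/sentencepiece version conflicts; it is slower
+# but avoids descriptor import errors.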

 @st.cache_resource
+def load_respondent():
+    model_id = "google/flan-t5-small"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+    return pipeline("text2text-generation", model=model, tokenizer=tokenizer)
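+# st.cache_resource caches the loaded pipeline across Streamlit reruns, so the
+# model is downloaded and initialized only once per server process.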

+generator = load_respondent()

 if "history" not in st.session_state:
     st.session_state.history = []
+    st.session_state.summary = ""
+
+# -- STYLING --
+st.markdown("""
+<style>
+body {
+    background-color: #111827;
+    color: #f3f4f6;
+}
+.stTextInput > div > div > input {
+    color: #f3f4f6;
+}
+</style>
+""", unsafe_allow_html=True)
+
+# -- HEADER --
+st.title("TARS.help")
+st.markdown("### A minimal AI that listens, reflects, and replies.")
+st.markdown(f"{datetime.now().strftime('%B %d, %Y')} | {len(st.session_state.history)//2} exchanges")

+# -- SAFETY FILTER --
+TRIGGER_PHRASES = ["kill myself", "end it all", "suicide", "not worth living", "can't go on"]

+def is_high_risk(text):
+    return any(phrase in text.lower() for phrase in TRIGGER_PHRASES)
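+# Plain substring matching is a minimal safety net; it misses misspellings and
+# paraphrases, so treat it as a floor rather than a classifier.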

+# -- INPUT --
+user_input = st.text_input("How are you feeling today?", placeholder="Start typing...")

+# -- REPLY FUNCTION --
+def generate_reply(context):
+    prompt = f"Respond empathetically to this conversation:\n{context}"
+    # do_sample=True so the temperature setting actually takes effect
+    result = generator(prompt, max_new_tokens=80, temperature=0.7, do_sample=True)[0]["generated_text"]
+    return result.strip()
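+# A text2text-generation pipeline returns only the generated reply (no prompt
+# echo), so no post-hoc splitting like the old split("AI:") is needed.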

+# -- CONVERSATION FLOW --
 if user_input:
+    # Keep the last four history entries (two exchanges) as rolling context
+    context = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history[-4:]])
+    with st.spinner("TARS is reflecting..."):
+        time.sleep(0.5)
+        if is_high_risk(user_input):
+            response = "I'm really sorry you're feeling this way. You're not alone – please talk to someone you trust or a mental health professional."
+        else:
+            full_context = context + f"\nUser: {user_input}"
+            response = generate_reply(full_context)
+    timestamp = datetime.now().strftime("%H:%M")
+    st.session_state.history.append(("You", user_input, timestamp))
+    st.session_state.history.append(("TARS", response, timestamp))
+
+# -- DISPLAY HISTORY --
+st.markdown("## Session")
+for speaker, msg, ts in st.session_state.history:  # 'ts', not 'time', to avoid shadowing the time module
+    st.markdown(f"**{speaker} [{ts}]:** {msg}")
+
+# -- SESSION SUMMARY --
+if st.button("Generate Session Summary"):
+    convo = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history])
+    prompt = f"Summarize this conversation in 2-3 sentences:\n{convo}"
+    try:
+        # do_sample=True so temperature=0.5 is honored
+        output = generator(prompt, max_new_tokens=120, temperature=0.5, do_sample=True)[0]["generated_text"]
+        st.session_state.summary = output.strip()
+    except Exception as e:
+        st.error("Summary generation failed.")
+        st.exception(e)
+
+if st.session_state.summary:
+    st.markdown("### Session Note")
+    st.markdown(st.session_state.summary)
+    st.download_button("Download Summary", st.session_state.summary, file_name="tars_session.txt")
+
+# -- FOOTER --
+st.markdown("---")
+st.caption("TARS is not a therapist but a quiet assistant that reflects with you.")
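
To try the updated app locally, a sketch (the exact dependency set is an assumption; flan-t5-small needs a backend such as PyTorch, and sentencepiece may be required by the T5 tokenizer):

pip install streamlit transformers torch sentencepiece
streamlit run app.py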
|