Update app.py
app.py CHANGED
@@ -2,16 +2,17 @@ import os
 import time
 from datetime import datetime
 import streamlit as st
-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch

 # -- SETUP --
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

 @st.cache_resource
 def load_model():
-    model_id = "
+    model_id = "tiiuae/falcon-rw-1b"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model =
+    model = AutoModelForCausalLM.from_pretrained(model_id)
     return tokenizer, model

 tokenizer, model = load_model()
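The commit pins the model to tiiuae/falcon-rw-1b and adds `import torch`, although torch is not used in any of the hunks shown here; presumably it serves dtype or no-grad control elsewhere in the file. If RAM on the Space is tight, a hedged variant of the loader (not part of this commit; `load_model_small` is a hypothetical name) could load the weights in bfloat16:

```python
# Hypothetical alternative to load_model(), sketched under the assumption
# that the Space runs on CPU with limited RAM; not part of this commit.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model_small():
    model_id = "tiiuae/falcon-rw-1b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # roughly halves weight memory vs. float32
        low_cpu_mem_usage=True,      # avoids a second full-size allocation while loading
    )
    model.eval()  # inference only; disables dropout
    return tokenizer, model
```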
@@ -20,17 +21,23 @@ if "history" not in st.session_state:
     st.session_state.history = []
     st.session_state.summary = ""

-# --
-def generate_text(prompt, max_new_tokens=150):
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-
-# -- HIGH-RISK FILTER --
+# -- SAFETY --
 TRIGGER_PHRASES = ["kill myself", "end it all", "suicide", "not worth living", "can't go on"]
 def is_high_risk(text):
     return any(phrase in text.lower() for phrase in TRIGGER_PHRASES)

+# -- GENERATE FUNCTION --
+def generate_response(prompt, max_new_tokens=120, temperature=0.7):
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
+        pad_token_id=tokenizer.eos_token_id
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+
 # -- STYLING --
 st.markdown("""
 <style>
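The new `generate_response` enables sampling (`do_sample=True` with a temperature) so replies vary between reruns, and passes `pad_token_id=tokenizer.eos_token_id` because the Falcon tokenizer ships without a pad token, which would otherwise make `generate` emit a warning. Note that a causal LM's `generate` output contains the prompt followed by the continuation, which is why callers downstream strip the prompt back off. A small usage sketch with a made-up prompt:

```python
# Usage sketch for generate_response as defined above (the prompt is made up).
# For causal LMs the decoded string begins with the prompt itself.
text = generate_response("User: I can't sleep lately.\nAI:", max_new_tokens=60)
reply = text.split("AI:")[-1].strip()  # same prompt-stripping idiom the app uses
print(reply)
```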
@@ -52,16 +59,17 @@ st.markdown(f"🗓️ {datetime.now().strftime('%B %d, %Y')} | {len(st.session_s
 # -- USER INPUT --
 user_input = st.text_input("How are you feeling today?", placeholder="Start typing...")

-# --
+# -- CHAT FLOW --
 if user_input:
     context = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history[-4:]])
     with st.spinner("TARS is reflecting..."):
-        time.sleep(
+        time.sleep(0.8)
         if is_high_risk(user_input):
             response = "I'm really sorry you're feeling this way. You're not alone — please talk to someone you trust or a mental health professional. 💙"
         else:
-            prompt = f"
-            response =
+            prompt = f"You are an empathetic AI. Here's the recent conversation:\n{context}\nUser: {user_input}\nAI:"
+            response = generate_response(prompt, max_new_tokens=100)
+            response = response.split("AI:")[-1].strip()
     timestamp = datetime.now().strftime("%H:%M")
     st.session_state.history.append(("🧠 You", user_input, timestamp))
     st.session_state.history.append(("🤖 TARS", response, timestamp))
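Splitting on "AI:" works only as long as the model never emits that marker inside its reply. A more robust alternative (a sketch, not what this commit does) decodes only the tokens generated after the prompt:

```python
# Sketch: decode only the newly generated tokens, so no string splitting is
# needed. Reuses the same tokenizer/model objects the app already loads.
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
)
new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]  # skip the echoed prompt
response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
```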
@@ -71,15 +79,15 @@ st.markdown("## 🗨️ Session")
 for speaker, msg, time in st.session_state.history:
     st.markdown(f"**{speaker} [{time}]:** {msg}")

-# -- SUMMARY
+# -- SUMMARY --
 if st.button("🧾 Generate Session Summary"):
     convo = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history])
-    summary_prompt = f"Summarize this conversation in 2-3
+    summary_prompt = f"Summarize this conversation in 2-3 sentences:\n{convo}\nSummary:"
     try:
-        summary =
-        st.session_state.summary = summary
+        summary = generate_response(summary_prompt, max_new_tokens=150, temperature=0.5)
+        st.session_state.summary = summary.split("Summary:")[-1].strip()
     except Exception as e:
-        st.error("❌
+        st.error("❌ Failed to generate summary.")
         st.exception(e)

 # -- DISPLAY SUMMARY --
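One pre-existing wrinkle the diff leaves untouched: the history loop binds the name `time`, shadowing the `time` module at module scope. It is harmless here only because `import time` re-runs at the top of every Streamlit rerun, but a rename avoids the trap entirely:

```python
# Sketch: rename the loop variable so it no longer shadows the time module.
for speaker, msg, ts in st.session_state.history:
    st.markdown(f"**{speaker} [{ts}]:** {msg}")
```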