# Page residue captured from the Hugging Face Spaces UI: "Spaces: Sleeping".
import os
import time
from datetime import datetime

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# -- SETUP --
# NOTE(review): protobuf reads this variable when it is first imported.
# `streamlit` (imported above, and by the `streamlit run` launcher) most
# likely imports protobuf before this line runs, so this assignment may have
# no effect here; setting it in the process environment (e.g. the Space's
# settings) is more reliable — confirm.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"


@st.cache_resource
def load_model():
    """Load the Falcon-RW-1B tokenizer and causal language model.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps a single tokenizer/model pair per server
    process instead of re-reading the checkpoint on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple as produced by ``from_pretrained``.
    """
    model_id = "tiiuae/falcon-rw-1b"
    # SECURITY: trust_remote_code=True executes Python code shipped in the
    # model repository. Consider pinning `revision=` to a reviewed commit.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    return tokenizer, model
tokenizer, model = load_model()

# First run of a browser session: seed the chat transcript and the summary.
# st.session_state persists across reruns, so this only fires once per session.
if "history" not in st.session_state:
    for _state_key, _initial_value in (("history", []), ("summary", "")):
        st.session_state[_state_key] = _initial_value
# -- SAFETY --
# Phrases that route a message to the fixed crisis reply instead of the model.
# Matching is a substring test on normalised text (case-folded, straight
# apostrophes, single spaces), so entries must be written in that form.
TRIGGER_PHRASES = [
    "kill myself",
    "killing myself",
    "end it all",
    "suicide",
    "suicidal",  # does not contain "suicide", so it needs its own entry
    "not worth living",
    "can't go on",
    "cant go on",
    "want to die",
]

# Typographic apostrophes that phone keyboards insert automatically.
_APOSTROPHE_VARIANTS = str.maketrans({"\u2019": "'", "\u2018": "'", "\u02bc": "'"})


def is_high_risk(text, phrases=None):
    """Return True if *text* contains any crisis phrase.

    Args:
        text: The user's message. ``None`` or ``""`` is treated as not risky.
        phrases: Optional iterable of phrases to check instead of
            ``TRIGGER_PHRASES``; they must already be normalised as
            described above.

    Returns:
        True when at least one phrase occurs in the normalised text.
    """
    if not text:
        return False
    # casefold + apostrophe unification + whitespace collapse, so
    # "I can’t   go on" still matches "can't go on".
    normalised = " ".join(text.casefold().translate(_APOSTROPHE_VARIANTS).split())
    candidates = TRIGGER_PHRASES if phrases is None else phrases
    return any(phrase in normalised for phrase in candidates)
# -- GENERATE FUNCTION --
def generate_response(prompt, max_new_tokens=120, temperature=0.7, return_full_text=True):
    """Sample a continuation of *prompt* from the module-level ``model``.

    Args:
        prompt: Text to condition on. It is tokenized with ``truncation=True``,
            so an over-long prompt is cut to the tokenizer's maximum length.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; sampling is always enabled.
        return_full_text: When True (the default, and the original
            behaviour) the decoded prompt is returned together with the
            continuation, so callers must strip the prompt themselves.
            When False only the newly generated tokens are decoded.

    Returns:
        The decoded text with special tokens removed and surrounding
        whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        # EOS doubles as the pad id (presumably the tokenizer defines no pad token).
        pad_token_id=tokenizer.eos_token_id,
    )
    tokens = outputs[0]
    if not return_full_text:
        # For a causal LM, generate() returns the prompt ids followed by the
        # new ids, so skipping the prompt length leaves only the continuation.
        tokens = tokens[inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(tokens, skip_special_tokens=True).strip()
# -- STYLING --
# NOTE(review): Streamlit draws its own app container on top of <body>, so
# the `body` rule may not visibly change the page background — confirm, and
# consider theming via `.streamlit/config.toml` instead.
st.markdown(
    """
<style>
body {
    background-color: #111827;
    color: #f3f4f6;
}
/* Light input text needs a dark fill: with only `color` set, typed text is
   near-invisible whenever the theme gives the input a light background
   (e.g. Streamlit's default light theme). */
.stTextInput > div > div > input {
    background-color: #1f2937;
    color: #f3f4f6;
}
</style>
""",
    unsafe_allow_html=True,
)
# -- HEADER --
# Emoji restored from mis-decoded UTF-8 (the scraped text showed Greek
# letters such as "π§" and "ποΈ" in their place).
st.title("🧠 TARS.help")
st.markdown("### A minimal AI that listens, reflects, and replies.")
# Two history entries (user + TARS) make one exchange. The count reflects
# history as it stands when this line runs, i.e. before the chat-flow
# section appends the current message.
st.markdown(
    f"🗓️ {datetime.now().strftime('%B %d, %Y')} | "
    f"{len(st.session_state.history) // 2} exchanges"
)
# -- USER INPUT --
# Single-line prompt box; Streamlit returns its current text ("" until typed).
user_input = st.text_input(
    label="How are you feeling today?",
    placeholder="Start typing...",
)
# -- CHAT FLOW --
def _chat_role(speaker):
    """Map a stored display label (e.g. "🤖 TARS") to the prompt role name."""
    return "AI" if speaker.endswith("TARS") else "User"


def _extract_reply(generated, prompt):
    """Return only the model's first reply turn from *generated*.

    ``generate_response`` returns the decoded prompt plus its continuation.
    Strip the prompt prefix (falling back to the text after the last "AI:"
    marker when the decoded prompt does not round-trip exactly), then cut
    where the model starts inventing a further dialogue turn.
    """
    if generated.startswith(prompt):
        reply = generated[len(prompt):]
    else:
        reply = generated.split("AI:")[-1]
    for marker in ("\nUser:", "\nAI:"):
        reply = reply.split(marker, 1)[0]
    return reply.strip()


# Streamlit reruns this script on every interaction (e.g. the summary and
# download buttons), and st.text_input keeps its last value across reruns.
# Remember what was already answered so a rerun does not re-submit it.
if not user_input:
    # An emptied box means the next message is new, even if it repeats the last one.
    st.session_state.last_processed_input = None
elif user_input != st.session_state.get("last_processed_input"):
    # Last two exchanges, labelled with the same User/AI roles the prompt uses.
    context = "\n".join(
        f"{_chat_role(s)}: {m}" for s, m, _ in st.session_state.history[-4:]
    )
    with st.spinner("TARS is reflecting..."):
        time.sleep(0.8)  # short pause so the spinner is visible
        if is_high_risk(user_input):
            # Crisis messages never reach the model; reply with a fixed message.
            response = (
                "I'm really sorry you're feeling this way. You're not alone — "
                "please talk to someone you trust or a mental health professional. 💙"
            )
        else:
            prompt = (
                "You are an empathetic AI. Here's the recent conversation:\n"
                f"{context}\nUser: {user_input}\nAI:"
            )
            generated = generate_response(prompt, max_new_tokens=100)
            response = _extract_reply(generated, prompt) or (
                "I'm here and listening. Could you tell me a bit more?"
            )
    timestamp = datetime.now().strftime("%H:%M")
    st.session_state.history.append(("🧍 You", user_input, timestamp))
    st.session_state.history.append(("🤖 TARS", response, timestamp))
    st.session_state.last_processed_input = user_input
# -- DISPLAY CHAT --
st.markdown("## 🗨️ Session")
# The loop variable is `ts`, not `time`: naming it `time` would rebind the
# `time` module imported at the top of the file.
for speaker, msg, ts in st.session_state.history:
    st.markdown(f"**{speaker} [{ts}]:** {msg}")
# -- SUMMARY --
if st.button("🧾 Generate Session Summary"):
    if not st.session_state.history:
        # Don't prompt the model with an empty transcript.
        st.info("Nothing to summarize yet — share how you're feeling first.")
    else:
        convo = "\n".join(f"{s}: {m}" for s, m, _ in st.session_state.history)
        summary_prompt = (
            f"Summarize this conversation in 2-3 sentences:\n{convo}\nSummary:"
        )
        try:
            summary = generate_response(
                summary_prompt, max_new_tokens=150, temperature=0.5
            )
        except Exception as e:  # UI boundary: show the failure instead of crashing the page
            st.error("❌ Failed to generate summary.")
            st.exception(e)
        else:
            # The returned text echoes the prompt; keep what follows the last "Summary:".
            st.session_state.summary = summary.split("Summary:")[-1].strip()
# -- DISPLAY SUMMARY --
# .get() tolerates a session whose "summary" key was never seeded.
session_summary = st.session_state.get("summary", "")
if session_summary:
    st.markdown("### 🧠 Session Note")
    st.markdown(session_summary)
    st.download_button(
        "📥 Download Summary", session_summary, file_name="tars_session.txt"
    )
# -- FOOTER --
# Horizontal rule, then a small-print disclaimer beneath it.
FOOTER_DISCLAIMER = "TARS is not a therapist but a quiet assistant that reflects with you."
st.markdown("---")
st.caption(FOOTER_DISCLAIMER)