"""TARS.help: a minimal Streamlit chat app backed by a small causal LM.

The conversation lives in ``st.session_state``. Each user message is first
screened against a short list of crisis phrases; otherwise the model is asked
for a reply. A session summary can be generated and downloaded.
"""

import os
import time
from datetime import datetime

# NOTE: protobuf reads this variable the first time it is imported, so it must
# be set before any third-party import that may load protobuf (presumably
# including Streamlit, whose UI messages are protobuf-based). When launched via
# `streamlit run`, the server may already have imported protobuf before this
# script executes; in that case export the variable in the shell instead.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import streamlit as st  # noqa: E402
import torch  # noqa: E402,F401
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402

# -- SETUP --
MODEL_ID = "tiiuae/falcon-rw-1b"
USER_LABEL = "🧍 You"
BOT_LABEL = "🤖 TARS"


@st.cache_resource
def load_model():
    """Load the tokenizer and model once and cache them across reruns/sessions.

    Returns:
        A ``(tokenizer, model)`` tuple for ``MODEL_ID``.
    """
    # Truncate from the left: if a prompt exceeds the tokenizer's configured
    # maximum length, drop the oldest conversation text rather than the
    # trailing "AI:" / "Summary:" cue the model is meant to continue from.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, truncation_side="left"
    )
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
    return tokenizer, model


tokenizer, model = load_model()

if "history" not in st.session_state:
    # Each entry is a (speaker_label, message, "HH:MM") tuple.
    st.session_state.history = []
if "summary" not in st.session_state:
    st.session_state.summary = ""

# -- SAFETY --
TRIGGER_PHRASES = ["kill myself", "end it all", "suicide", "not worth living", "can't go on"]
CRISIS_MESSAGE = (
    "I'm really sorry you're feeling this way. You're not alone — please talk to "
    "someone you trust or a mental health professional. 💙"
)


def is_high_risk(text):
    """Return True if *text* contains any of ``TRIGGER_PHRASES`` (case-insensitive).

    Typographic apostrophes (U+2019, common on mobile keyboards) are mapped to
    ASCII first so "can’t go on" matches the same as "can't go on".
    """
    normalized = text.lower().replace("\u2019", "'")
    return any(phrase in normalized for phrase in TRIGGER_PHRASES)


# -- GENERATE FUNCTION --
def generate_response(prompt, max_new_tokens=120, temperature=0.7):
    """Sample a continuation of *prompt* from the cached model.

    Args:
        prompt: Full text prompt to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        Only the newly generated text (the prompt is not echoed back),
        decoded without special tokens and stripped of outer whitespace.
    """
    # Defensive: the causal LM takes input_ids/attention_mask only; keep any
    # token_type_ids out of the **inputs forwarded to generate().
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, return_token_type_ids=False
    )
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode just the new tokens.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def _format_transcript(entries):
    """Render history entries as "User: ..." / "AI: ..." lines for a prompt.

    The prompts end with plain "User:"/"AI:" cues, so the context uses the
    same role names instead of the emoji display labels.
    """
    return "\n".join(
        f"{'User' if speaker == USER_LABEL else 'AI'}: {message}"
        for speaker, message, _ in entries
    )


def _clean_reply(text):
    """Keep only the first AI turn of a raw continuation.

    Small base models tend to keep writing the dialogue, inventing further
    "User:" / "AI:" turns; everything from the first such marker is dropped.
    """
    for marker in ("\nUser:", "\nAI:"):
        text = text.split(marker, 1)[0]
    return text.strip()


# -- STYLING --
# NOTE(review): this style block is empty; the CSS appears to be missing.
# TODO restore it or remove this call.
st.markdown("""
""", unsafe_allow_html=True)

# -- HEADER --
st.title("🧠 TARS.help")
st.markdown("### A minimal AI that listens, reflects, and replies.")
# Filled in after the chat flow so the count includes a message sent this run.
stats_line = st.empty()

# -- USER INPUT --
# A form that clears on submit processes each message exactly once. A bare
# st.text_input keeps its value across reruns, so any other interaction
# (e.g. the summary button) would re-send the last message.
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("How are you feeling today?", placeholder="Start typing...")
    submitted = st.form_submit_button("Send")

# -- CHAT FLOW --
if submitted and user_input.strip():
    response = None
    with st.spinner("TARS is reflecting..."):
        time.sleep(0.8)
        if is_high_risk(user_input):
            response = CRISIS_MESSAGE
        else:
            context = _format_transcript(st.session_state.history[-4:])
            prompt = (
                "You are an empathetic AI.\n"
                f"Here's the recent conversation:\n{context}\n"
                f"User: {user_input}\nAI:"
            )
            try:
                response = _clean_reply(generate_response(prompt, max_new_tokens=100))
            except Exception as e:  # UI boundary: report instead of crashing the page
                st.error("❌ Failed to generate a response.")
                st.exception(e)
            else:
                if not response:
                    response = "I'm here and listening. Could you tell me a little more?"

    # Only record the exchange if a reply was actually produced.
    if response is not None:
        timestamp = datetime.now().strftime("%H:%M")
        st.session_state.history.append((USER_LABEL, user_input, timestamp))
        st.session_state.history.append((BOT_LABEL, response, timestamp))

stats_line.markdown(
    f"🗓️ {datetime.now().strftime('%B %d, %Y')} | "
    f"{len(st.session_state.history) // 2} exchanges"
)

# -- DISPLAY CHAT --
st.markdown("## 🗨️ Session")
for speaker, msg, ts in st.session_state.history:
    st.markdown(f"**{speaker} [{ts}]:** {msg}")

# -- SUMMARY --
if st.button("🧾 Generate Session Summary"):
    if not st.session_state.history:
        st.info("Nothing to summarize yet — start a conversation first.")
    else:
        convo = _format_transcript(st.session_state.history)
        summary_prompt = f"Summarize this conversation in 2-3 sentences:\n{convo}\nSummary:"
        try:
            summary = generate_response(summary_prompt, max_new_tokens=150, temperature=0.5)
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error("❌ Failed to generate summary.")
            st.exception(e)
        else:
            st.session_state.summary = summary

# -- DISPLAY SUMMARY --
if st.session_state.summary:
    st.markdown("### 🧠 Session Note")
    st.markdown(st.session_state.summary)
    st.download_button(
        "📥 Download Summary", st.session_state.summary, file_name="tars_session.txt"
    )

# -- FOOTER --
st.markdown("---")
st.caption("TARS is not a therapist but a quiet assistant that reflects with you.")