File size: 3,689 Bytes
f2169e5
e988eb0
 
 
5b5be76
 
b5b9af8
e988eb0
 
b5b9af8
 
25a7813
5b5be76
e204e60
 
25a7813
b5b9af8
25a7813
b5b9af8
95b2ec1
 
e988eb0
 
5b5be76
25a7813
 
 
 
5b5be76
 
 
 
 
 
 
 
 
 
 
 
e988eb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95b2ec1
25a7813
e988eb0
b5b9af8
5b5be76
95b2ec1
e988eb0
 
5b5be76
e988eb0
 
 
5b5be76
 
 
25a7813
 
 
e988eb0
25a7813
e988eb0
 
 
 
5b5be76
e988eb0
 
5b5be76
e988eb0
5b5be76
 
e988eb0
5b5be76
e988eb0
 
25a7813
e988eb0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import time
from datetime import datetime
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# -- SETUP --
# Sets the protobuf implementation selector for this process to "python".
# NOTE(review): presumably this forces protobuf's pure-Python backend; it is
# assigned after the imports at the top of the file, so if any of them has
# already imported protobuf the setting may come too late -- TODO confirm.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

@st.cache_resource
def load_model():
    """Load the tokenizer and causal language model used by the app.

    Both objects are fetched for the ``tiiuae/falcon-rw-1b`` checkpoint. The
    function is wrapped in ``st.cache_resource`` so Streamlit can reuse the
    loaded objects instead of loading them again on each script run.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "tiiuae/falcon-rw-1b"
    # NOTE(review): trust_remote_code=True allows Python code shipped in the
    # model repository to be executed locally -- confirm it is still required
    # for this checkpoint before keeping it.
    shared_options = {"trust_remote_code": True}
    return (
        AutoTokenizer.from_pretrained(checkpoint, **shared_options),
        AutoModelForCausalLM.from_pretrained(checkpoint, **shared_options),
    )

# Module-level handles used by generate_response().
tokenizer, model = load_model()

# Create the per-session chat state only when it does not exist yet.
if "history" not in st.session_state:
    # summary: text of the most recently generated session note ("" = none).
    st.session_state["summary"] = ""
    # history: list of (speaker, message, "HH:MM" timestamp) tuples.
    st.session_state["history"] = []

# -- SAFETY --
TRIGGER_PHRASES = ["kill myself", "end it all", "suicide", "not worth living", "can't go on"]

# Typographic apostrophes (typical of phone keyboards and pasted text) mapped
# to the ASCII apostrophe used in TRIGGER_PHRASES, so "can’t go on" matches.
_APOSTROPHE_TABLE = str.maketrans({"\u2018": "'", "\u2019": "'", "\u02bc": "'"})


def is_high_risk(text):
    """Return True when ``text`` contains any phrase from TRIGGER_PHRASES.

    Matching is a case-insensitive substring search. Before searching, curly
    apostrophes are converted to ``'`` and every run of whitespace (including
    newlines) is collapsed to one space, so formatting differences cannot hide
    a phrase.

    Args:
        text: The user's message. ``None`` or an empty string is treated as
            containing no trigger phrase.

    Returns:
        bool: True if at least one trigger phrase occurs in the message.
    """
    if not text:
        return False
    normalized = " ".join(text.translate(_APOSTROPHE_TABLE).lower().split())
    return any(phrase in normalized for phrase in TRIGGER_PHRASES)

# -- GENERATE FUNCTION --
def generate_response(prompt, max_new_tokens=120, temperature=0.7):
    """Sample text from the module-level ``model`` for the given prompt.

    Args:
        prompt: Text passed to the module-level ``tokenizer`` (with
            ``truncation=True``) and then to ``model.generate``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``;
            sampling is always enabled (``do_sample=True``).

    Returns:
        The first generated sequence decoded with special tokens skipped and
        surrounding whitespace stripped. The callers in this file split the
        result on their prompt's trailing marker, presumably because the
        decoded text still contains the prompt -- verify if that matters.
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        # The end-of-sequence token id doubles as the padding id.
        "pad_token_id": tokenizer.eos_token_id,
    }
    generated = model.generate(**encoded, **sampling_options)
    first_sequence = generated[0]
    decoded = tokenizer.decode(first_sequence, skip_special_tokens=True)
    return decoded.strip()

# -- STYLING --
# Raw CSS injected into the page: dark background with light text on `body`,
# and light text inside text-input fields.
DARK_THEME_CSS = """
    <style>
    body {
        background-color: #111827;
        color: #f3f4f6;
    }
    .stTextInput > div > div > input {
        color: #f3f4f6;
    }
    </style>
"""
# unsafe_allow_html is needed so the <style> tag is emitted as markup.
st.markdown(DARK_THEME_CSS, unsafe_allow_html=True)

# -- HEADER --
st.title("🧠 TARS.help")
st.markdown("### A minimal AI that listens, reflects, and replies.")
# Each exchange stores two history entries (user + TARS), hence the // 2.
# The count is computed before the chat-flow section further down appends the
# current exchange, so it reflects the state at the start of this script run.
exchange_count = len(st.session_state.history) // 2
today = datetime.now().strftime('%B %d, %Y')
# The calendar emoji was stored mis-encoded ("πŸ—“οΈ"); restored to the real glyph.
st.markdown(f"🗓️ {today} | {exchange_count} exchanges")

# -- USER INPUT --
# Single-line text box; its current value drives the chat flow below.
input_label = "How are you feeling today?"
user_input = st.text_input(input_label, placeholder="Start typing...")

# -- CHAT FLOW --
# Streamlit re-runs this script on every widget interaction and the text box
# keeps its value, so without a guard any later button click would answer the
# same message again and append a duplicate exchange. `last_input` remembers
# the message that has already been handled.
if not user_input:
    # An emptied box resets the guard so the same text can be sent again.
    st.session_state.last_input = ""
elif user_input != st.session_state.get("last_input"):
    # Only the last four history entries are given to the model as context.
    context = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history[-4:]])
    with st.spinner("TARS is reflecting..."):
        time.sleep(0.8)
        if is_high_risk(user_input):
            # Fixed, non-generated reply for messages that trip the safety check.
            response = "I'm really sorry you're feeling this way. You're not alone — please talk to someone you trust or a mental health professional. 💙"
        else:
            prompt = f"You are an empathetic AI. Here's the recent conversation:\n{context}\nUser: {user_input}\nAI:"
            raw = generate_response(prompt, max_new_tokens=100)
            # The generated text normally starts with the prompt itself; keep
            # only what follows it. If decoding did not reproduce the prompt
            # verbatim, fall back to the text after the last "AI:" marker.
            if raw.startswith(prompt):
                completion = raw[len(prompt):]
            else:
                completion = raw.split("AI:")[-1]
            # Drop any further dialogue turns the model invented after its reply.
            for marker in ("\nUser:", "\nAI:"):
                completion = completion.split(marker)[0]
            response = completion.strip()
    timestamp = datetime.now().strftime("%H:%M")
    st.session_state.history.append(("🧍 You", user_input, timestamp))
    st.session_state.history.append(("🤖 TARS", response, timestamp))
    # Mark the message as handled only once the exchange has been recorded.
    st.session_state.last_input = user_input

# -- DISPLAY CHAT --
# Heading emoji was stored mis-encoded ("πŸ—¨οΈ"); restored to the real glyph.
st.markdown("## 🗨️ Session")
# Render every stored (speaker, message, timestamp) entry in order. The loop
# variable is named `stamp` so it does not rebind the imported `time` module.
for speaker, msg, stamp in st.session_state.history:
    st.markdown(f"**{speaker} [{stamp}]:** {msg}")

# -- SUMMARY --
if st.button("🧾 Generate Session Summary"):
    if not st.session_state.history:
        # With no messages the model would be asked to summarize nothing.
        st.info("Nothing to summarize yet. Send a message first.")
    else:
        convo = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history])
        summary_prompt = f"Summarize this conversation in 2-3 sentences:\n{convo}\nSummary:"
        try:
            summary = generate_response(summary_prompt, max_new_tokens=150, temperature=0.5)
        except Exception as e:
            # UI boundary: report the failure on the page instead of crashing.
            st.error("❌ Failed to generate summary.")
            st.exception(e)
        else:
            # Keep only the text after the last "Summary:" marker.
            st.session_state.summary = summary.split("Summary:")[-1].strip()

# -- DISPLAY SUMMARY --
# Show the stored summary (if any) and offer it as a plain-text download.
if st.session_state.summary:
    st.markdown("### 🧠 Session Note")
    st.markdown(st.session_state.summary)
    # Button emoji was stored mis-encoded ("πŸ“₯"); restored to the real glyph.
    st.download_button("📥 Download Summary", st.session_state.summary, file_name="tars_session.txt")

# -- FOOTER --
# Horizontal rule followed by a small-print disclaimer.
st.markdown("---")
footer_note = "TARS is not a therapist but a quiet assistant that reflects with you."
st.caption(footer_note)