Divymakesml committed on
Commit e988eb0 · verified · 1 Parent(s): a8706c6

Update app.py

Files changed (1)
  1. app.py +78 -97
app.py CHANGED
@@ -1,112 +1,93 @@
- import streamlit as st
  import os
- os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
- os.system("pip install tensorflow-cpu==2.11.0")
- import tensorflow
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- from tensorflow.keras.models import load_model
- from tensorflow.keras.preprocessing.text import Tokenizer
- from tensorflow.keras.preprocessing.sequence import pad_sequences
-
- # Load tokenizer used in training
- tokenizer = Tokenizer(num_words=10000)
- # You must re-train or load tokenizer from a JSON if you saved it!
- tokenizer.fit_on_texts(["dummy"])  # Temporary; replace with loaded tokenizer
-
- # Preprocess text for models
- def preprocess(text):
-     sequence = tokenizer.texts_to_sequences([text])
-     return pad_sequences(sequence, maxlen=100)
-
- # Load Keras models
- model1 = load_model("model1.h5")  # Suicide risk
- model2 = load_model("best_model (2).keras")  # Diagnosis classifier
-
- # Model prediction wrappers
- def model1_predict(text):
-     pred = model1.predict(preprocess(text))[0][0]
-     return int(pred > 0.5)
-
- def model2_predict(text):
-     pred = model2.predict(preprocess(text))[0]
-     return int(np.argmax(pred))

- diagnosis_labels = {
-     1: "Anxiety",
-     2: "Depression",
-     3: "Bipolar disorder",
-     4: "PTSD",
-     5: "OCD",
-     6: "ADHD",
-     7: "General emotional distress"
- }

  @st.cache_resource
- def load_llm():
-     model_id = "tiiuae/falcon-7b-instruct"
      tokenizer = AutoTokenizer.from_pretrained(model_id)
-     model = AutoModelForCausalLM.from_pretrained(
-         model_id,
-         device_map="auto",
-         trust_remote_code=True,
-         torch_dtype="auto"
-     )
-     return pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")

- generator = load_llm()

- # Session memory
  if "history" not in st.session_state:
      st.session_state.history = []

- def therapist_pipeline(user_input):
-     st.session_state.history.append(f"User: {user_input}")
-     risk = model1_predict(user_input)
-
-     if risk == 1:
-         response = (
-             "I'm really sorry you're feeling this way. You're not alone — please talk to someone you trust "
-             "or a professional. I'm here to listen, but it's important to get real support too. Please contact 9-8-8 if you need immediate support. I hope you get better. 💙"
-         )
-     else:
-         diagnosis_code = model2_predict(user_input)
-         diagnosis = diagnosis_labels.get(diagnosis_code, "General emotional distress")
-
-         prompt = f"""You are an empathetic AI therapist. The user has been diagnosed with {diagnosis}. Respond supportively.
-
- User: {user_input}
- AI:"""
-
-         response = generator(prompt, max_new_tokens=150, temperature=0.7)[0]["generated_text"]
-         response = response.split("AI:")[-1].strip()
-
-     st.session_state.history.append(f"AI: {response}")
-     return response
-
- def summarize_session():
-     session_text = "\n".join(st.session_state.history)
-     prompt = f"""Summarize the emotional state of the user based on the following conversation. Include emotional cues and possible diagnoses. Write it like a therapist note.

- Conversation:
- {session_text}

- Summary:"""
-     summary = generator(prompt, max_new_tokens=250, temperature=0.5)[0]["generated_text"]
-     return summary.split("Summary:")[-1].strip()

- # Streamlit UI
- st.title("🧠 TARS.help")
- user_input = st.text_input("How are you feeling today?")

  if user_input:
-     response = therapist_pipeline(user_input)
-     st.markdown(f"**AI Therapist:** {response}")
-
- if st.button("🧾 Generate Therapist Summary"):
-     st.markdown("### 🧠 Session Summary")
-     st.markdown(summarize_session())
-
- # Show history
- for i in range(0, len(st.session_state.history), 2):
-     st.markdown(f"**You:** {st.session_state.history[i][6:]}")
-     st.markdown(f"**AI:** {st.session_state.history[i+1][4:]}")
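
Note: in the removed version, the Tokenizer was re-fitted on ["dummy"], so preprocess() never matched the training vocabulary, and model2_predict() called np.argmax without importing numpy. A minimal sketch of what the old inline comment asked for, assuming the training tokenizer had been saved with tokenizer.to_json() (the file name here is hypothetical):

    from tensorflow.keras.preprocessing.text import tokenizer_from_json

    # Reload the tokenizer fitted during training so inference shares its vocabulary.
    with open("tokenizer.json") as f:
        tokenizer = tokenizer_from_json(f.read())
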
  import os
+ import time
+ from datetime import datetime
+ import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+
+ # -- SETUP --
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
  @st.cache_resource
+ def load_respondent():
+     model_id = "google/flan-t5-small"
      tokenizer = AutoTokenizer.from_pretrained(model_id)
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+     return pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+
+ generator = load_respondent()

  if "history" not in st.session_state:
      st.session_state.history = []
+     st.session_state.summary = ""
+
+ # -- STYLING --
+ st.markdown("""
+ <style>
+ body {
+     background-color: #111827;
+     color: #f3f4f6;
+ }
+ .stTextInput > div > div > input {
+     color: #f3f4f6;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # -- HEADER --
+ st.title("🧠 TARS.help")
+ st.markdown("### A minimal AI that listens, reflects, and replies.")
+ st.markdown(f"🗓️ {datetime.now().strftime('%B %d, %Y')} | {len(st.session_state.history)//2} exchanges")

+ # -- SAFETY FILTER --
+ TRIGGER_PHRASES = ["kill myself", "end it all", "suicide", "not worth living", "can't go on"]

+ def is_high_risk(text):
+     return any(phrase in text.lower() for phrase in TRIGGER_PHRASES)
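
Note: a plain substring scan also matches phrases embedded in longer text where they were not said as such ("blend it all" contains "end it all"). A word-boundary variant, as a sketch reusing the same TRIGGER_PHRASES list:

    import re

    # Precompile each phrase with word boundaries so matches align with whole words.
    RISK_PATTERNS = [re.compile(rf"\b{re.escape(p)}\b") for p in TRIGGER_PHRASES]

    def is_high_risk(text):
        lowered = text.lower()
        return any(p.search(lowered) for p in RISK_PATTERNS)
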
+ # -- INPUT --
+ user_input = st.text_input("How are you feeling today?", placeholder="Start typing...")

+ # -- REPLY FUNCTION --
+ def generate_reply(context):
+     prompt = f"Respond empathetically to this conversation:\n{context}"
+     result = generator(prompt, max_new_tokens=80, temperature=0.7)[0]["generated_text"]
+     return result.strip()
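
Note: transformers generation ignores temperature unless sampling is enabled, so this call decodes greedily and will warn. If sampled output is actually intended, the fix is one flag (sketch):

    result = generator(prompt, max_new_tokens=80, do_sample=True, temperature=0.7)[0]["generated_text"]
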
+ # -- CONVERSATION FLOW --
  if user_input:
+     context = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history[-4:]])
+     with st.spinner("TARS is reflecting..."):
+         time.sleep(0.5)
+         if is_high_risk(user_input):
+             response = "I'm really sorry you're feeling this way. You're not alone — please talk to someone you trust or a mental health professional. 💙"
+         else:
+             full_context = context + f"\nUser: {user_input}"
+             response = generate_reply(full_context)
+     timestamp = datetime.now().strftime("%H:%M")
+     st.session_state.history.append(("🧍 You", user_input, timestamp))
+     st.session_state.history.append(("🤖 TARS", response, timestamp))
+
+ # -- DISPLAY HISTORY --
+ st.markdown("## 🗨️ Session")
+ for speaker, msg, time in st.session_state.history:
+     st.markdown(f"**{speaker} [{time}]:** {msg}")
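
Note: the loop variable time shadows the time module imported at the top. Nothing breaks, because Streamlit re-runs the script from the top on each interaction and time.sleep() executes before this loop, but a distinct name is clearer (sketch):

    for speaker, msg, ts in st.session_state.history:
        st.markdown(f"**{speaker} [{ts}]:** {msg}")
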
+
+ # -- SESSION SUMMARY --
+ if st.button("🧾 Generate Session Summary"):
+     convo = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history])
+     prompt = f"Summarize this conversation in 2-3 sentences:\n{convo}"
+     try:
+         output = generator(prompt, max_new_tokens=120, temperature=0.5)[0]['generated_text']
+         st.session_state.summary = output.strip()
+     except Exception as e:
+         st.error("❌ Summary generation failed.")
+         st.exception(e)
+
+ if st.session_state.summary:
+     st.markdown("### 🧠 Session Note")
+     st.markdown(st.session_state.summary)
+     st.download_button("📥 Download Summary", st.session_state.summary, file_name="tars_session.txt")
+
+ # -- FOOTER --
+ st.markdown("---")
+ st.caption("TARS is not a therapist but a quiet assistant that reflects with you.")
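
Note: with the runtime pip install of tensorflow-cpu gone, this revision only needs the lighter stack. A plausible requirements.txt for the Space (an assumption, not part of the commit):

    streamlit
    transformers
    torch
    sentencepiece  # used by the T5 tokenizer if the slow tokenizer is loaded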