Divymakesml committed on
Commit 4208dd4 · verified · 1 Parent(s): e204e60

Update app.py

Files changed (1)
  1. app.py +25 -32
app.py CHANGED
@@ -2,41 +2,36 @@ import os
 import time
 from datetime import datetime
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
 # -- SETUP --
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 
 @st.cache_resource
-def load_model():
-    model_id = "tiiuae/falcon-rw-1b"
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
-    return tokenizer, model
+def load_pipeline():
+    model_id = "tiiuae/falcon-7b-instruct"
+    pipe = pipeline(
+        "text-generation",
+        model=AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True),
+        tokenizer=AutoTokenizer.from_pretrained(model_id, trust_remote_code=True),
+        device_map="auto"
+    )
+    return pipe
 
-tokenizer, model = load_model()
+generator = load_pipeline()
 
 if "history" not in st.session_state:
     st.session_state.history = []
    st.session_state.summary = ""
 
-# -- SAFETY --
+# -- UTILS --
 TRIGGER_PHRASES = ["kill myself", "end it all", "suicide", "not worth living", "can't go on"]
 def is_high_risk(text):
     return any(phrase in text.lower() for phrase in TRIGGER_PHRASES)
 
-# -- GENERATE FUNCTION --
-def generate_response(prompt, max_new_tokens=120, temperature=0.7):
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+def get_reply(prompt, max_new_tokens=150, temperature=0.7):
+    out = generator(prompt, max_new_tokens=max_new_tokens, temperature=temperature)[0]["generated_text"]
+    return out.split("AI:")[-1].strip() if "AI:" in out else out.strip()
 
 # -- STYLING --
 st.markdown("""
@@ -59,22 +54,21 @@ st.markdown(f"🗓️ {datetime.now().strftime('%B %d, %Y')} | {len(st.session_s
 # -- USER INPUT --
 user_input = st.text_input("How are you feeling today?", placeholder="Start typing...")
 
-# -- CHAT FLOW --
+# -- MAIN FLOW --
 if user_input:
     context = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history[-4:]])
     with st.spinner("TARS is reflecting..."):
-        time.sleep(0.8)
+        time.sleep(1)
         if is_high_risk(user_input):
             response = "I'm really sorry you're feeling this way. You're not alone — please talk to someone you trust or a mental health professional. 💙"
         else:
-            prompt = f"You are an empathetic AI. Here's the recent conversation:\n{context}\nUser: {user_input}\nAI:"
-            response = generate_response(prompt, max_new_tokens=100)
-            response = response.split("AI:")[-1].strip()
+            prompt = f"You are a kind and calm AI assistant.\n{context}\nUser: {user_input}\nAI:"
+            response = get_reply(prompt, max_new_tokens=150)
     timestamp = datetime.now().strftime("%H:%M")
     st.session_state.history.append(("🧍 You", user_input, timestamp))
     st.session_state.history.append(("🤖 TARS", response, timestamp))
 
-# -- DISPLAY CHAT --
+# -- CHAT DISPLAY --
 st.markdown("## 🗨️ Session")
 for speaker, msg, time in st.session_state.history:
     st.markdown(f"**{speaker} [{time}]:** {msg}")
@@ -82,15 +76,14 @@ for speaker, msg, time in st.session_state.history:
 # -- SUMMARY --
 if st.button("🧾 Generate Session Summary"):
     convo = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history])
-    summary_prompt = f"Summarize this conversation in 2-3 sentences:\n{convo}\nSummary:"
+    prompt = f"Summarize this conversation in 3 reflective sentences:\n{convo}\nSummary:"
     try:
-        summary = generate_response(summary_prompt, max_new_tokens=150, temperature=0.5)
-        st.session_state.summary = summary.split("Summary:")[-1].strip()
+        summary = get_reply(prompt, max_new_tokens=200, temperature=0.5)
+        st.session_state.summary = summary
     except Exception as e:
-        st.error("❌ Failed to generate summary.")
+        st.error("Summary generation failed.")
         st.exception(e)
 
-# -- DISPLAY SUMMARY --
 if st.session_state.summary:
     st.markdown("### 🧠 Session Note")
     st.markdown(st.session_state.summary)
@@ -98,4 +91,4 @@ if st.session_state.summary:
 
 # -- FOOTER --
 st.markdown("---")
-st.caption("TARS is not a therapist but a quiet assistant that reflects with you.")
+st.caption("TARS is not a therapist, but a quiet assistant that reflects with you.")
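Note on the new generation path: get_reply() forwards temperature to the pipeline without do_sample=True, and under greedy decoding (the usual default) that temperature is silently ignored; device_map="auto" also needs the accelerate package installed. Below is a minimal sketch of the same call with those settings made explicit; the do_sample and return_full_text arguments are suggestions, not part of this commit.

# Sketch only: assumes transformers and accelerate are installed and there is
# enough GPU memory for falcon-7b-instruct (roughly 14 GB in float16).
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="tiiuae/falcon-7b-instruct",  # pipeline() resolves model + tokenizer from the id
    trust_remote_code=True,
    device_map="auto",                  # requires the accelerate package
)

def get_reply(prompt, max_new_tokens=150, temperature=0.7):
    out = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,          # without this, temperature has no effect
        temperature=temperature,
        return_full_text=False,  # return only the completion, not the echoed prompt
    )
    return out[0]["generated_text"].strip()

With return_full_text=False, the split("AI:") post-processing becomes unnecessary, since the pipeline no longer echoes the prompt back in its output.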
 
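One pre-existing detail the diff leaves untouched (visible in the context lines): the chat display loop unpacks each history tuple into a variable named time, which shadows the time module for the rest of the script run. A hypothetical rename avoids surprises if code added below the loop ever calls time.sleep():

# Hypothetical rename of the loop variable; behavior is otherwise identical.
for speaker, msg, ts in st.session_state.history:
    st.markdown(f"**{speaker} [{ts}]:** {msg}")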