Divymakesml committed
Commit 5b5be76 · verified · 1 Parent(s): 25a7813

Update app.py

Files changed (1)
  1. app.py +27 -19
app.py CHANGED
@@ -2,16 +2,17 @@ import os
 import time
 from datetime import datetime
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
 # -- SETUP --
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 
 @st.cache_resource
 def load_model():
-    model_id = "google/flan-t5-base"
+    model_id = "tiiuae/falcon-rw-1b"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id)
     return tokenizer, model
 
 tokenizer, model = load_model()
@@ -20,17 +21,23 @@ if "history" not in st.session_state:
     st.session_state.history = []
     st.session_state.summary = ""
 
-# -- TEXT GENERATION FUNCTION --
-def generate_text(prompt, max_new_tokens=150):
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-
-# -- HIGH-RISK FILTER --
+# -- SAFETY --
 TRIGGER_PHRASES = ["kill myself", "end it all", "suicide", "not worth living", "can't go on"]
 def is_high_risk(text):
     return any(phrase in text.lower() for phrase in TRIGGER_PHRASES)
 
+# -- GENERATE FUNCTION --
+def generate_response(prompt, max_new_tokens=120, temperature=0.7):
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
+        pad_token_id=tokenizer.eos_token_id
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+
 # -- STYLING --
 st.markdown("""
 <style>
@@ -52,16 +59,17 @@ st.markdown(f"🗓️ {datetime.now().strftime('%B %d, %Y')} | {len(st.session_s
 # -- USER INPUT --
 user_input = st.text_input("How are you feeling today?", placeholder="Start typing...")
 
-# -- MAIN CHAT LOGIC --
+# -- CHAT FLOW --
 if user_input:
     context = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history[-4:]])
     with st.spinner("TARS is reflecting..."):
-        time.sleep(1.2)
+        time.sleep(0.8)
         if is_high_risk(user_input):
            response = "I'm really sorry you're feeling this way. You're not alone — please talk to someone you trust or a mental health professional. 💙"
         else:
-            prompt = f"Respond with empathy:\n{context}\nUser: {user_input}"
-            response = generate_text(prompt, max_new_tokens=100)
+            prompt = f"You are an empathetic AI. Here's the recent conversation:\n{context}\nUser: {user_input}\nAI:"
+            response = generate_response(prompt, max_new_tokens=100)
+            response = response.split("AI:")[-1].strip()
     timestamp = datetime.now().strftime("%H:%M")
     st.session_state.history.append(("🧍 You", user_input, timestamp))
     st.session_state.history.append(("🤖 TARS", response, timestamp))
@@ -71,15 +79,15 @@ st.markdown("## 🗨️ Session")
 for speaker, msg, time in st.session_state.history:
     st.markdown(f"**{speaker} [{time}]:** {msg}")
 
-# -- SUMMARY GENERATION --
+# -- SUMMARY --
 if st.button("🧾 Generate Session Summary"):
     convo = "\n".join([f"{s}: {m}" for s, m, _ in st.session_state.history])
-    summary_prompt = f"Summarize this conversation in 2-3 thoughtful sentences:\n{convo}"
+    summary_prompt = f"Summarize this conversation in 2-3 sentences:\n{convo}\nSummary:"
    try:
-        summary = generate_text(summary_prompt, max_new_tokens=150)
-        st.session_state.summary = summary
+        summary = generate_response(summary_prompt, max_new_tokens=150, temperature=0.5)
+        st.session_state.summary = summary.split("Summary:")[-1].strip()
     except Exception as e:
-        st.error("❌ Summary generation failed.")
+        st.error("❌ Failed to generate summary.")
         st.exception(e)
 
 # -- DISPLAY SUMMARY --
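
The substance of the commit is the swap from google/flan-t5-base, an encoder-decoder model, to tiiuae/falcon-rw-1b, a decoder-only model, and the other changes follow from it: a causal LM returns its prompt along with the continuation, and falcon-rw-1b's GPT-2-style tokenizer defines no pad token, hence pad_token_id=tokenizer.eos_token_id. A minimal standalone sketch of the new generation path (model ID, sampling settings, and the "AI:" sentinel are taken from the diff; the sample conversation is invented):

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "tiiuae/falcon-rw-1b"  # model introduced by this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Prompt format from the diff; the conversation content here is invented.
prompt = "You are an empathetic AI. Here's the recent conversation:\nUser: I had a rough day.\nAI:"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,  # falcon-rw-1b defines no pad token
)

# Decoder-only models return prompt + continuation in one sequence,
# which is why the app now splits on the "AI:" sentinel.
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(full_text.split("AI:")[-1].strip())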
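
One caveat worth noting on the sentinel approach: split("AI:")[-1] keeps only the text after the last "AI:", so a reply that itself contains "AI:" gets truncated. A common alternative, not used in this commit, is to slice the prompt off by token count; continuing from the sketch above:

# Hypothetical alternative to splitting on "AI:" (not part of the commit):
# decode only the tokens generated after the prompt.
prompt_len = inputs["input_ids"].shape[1]
reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()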