ghuman7 committed on
Commit 57642d9 · verified · 1 Parent(s): 3e182cf

Update app.py

Files changed (1)
  1. app.py +24 -55
app.py CHANGED
@@ -1,63 +1,32 @@
 import streamlit as st
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
-from sentence_transformers import SentenceTransformer
-import faiss
-import torch
 
-# Title of the Streamlit app
-st.title("Mental Health Chatbot")
-
-# Load a pre-trained sentence transformer model for embedding
-st.write("Loading models... Please wait.")
-sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-# Load the RAG model, tokenizer, and retriever
-tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
-retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
-rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+# Load the RAG model components
+@st.cache_resource
+def load_rag_model():
+    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
+    rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+    return tokenizer, retriever, rag_model
 
-# Sample dialogues related to mental health (replace with actual dataset for production)
-sample_dialogues = [
-    "I'm feeling really down lately and don't know what to do.",
-    "I just lost my job, and I'm worried about the future.",
-    "I'm having trouble sleeping and feeling anxious all the time.",
-    "I've been feeling isolated and lonely.",
-    "I don't have the energy to do anything, and it's affecting my work."
-]
+tokenizer, retriever, rag_model = load_rag_model()
 
-# Embed the sample dialogues using the sentence transformer model
-embeddings = sentence_model.encode(sample_dialogues, convert_to_tensor=True)
-
-# Build FAISS index
-index = faiss.IndexFlatL2(embeddings.shape[1])
-index.add(embeddings.cpu().numpy())
+# Streamlit UI for Mental Health Chatbot
+st.title("Mental Health Chatbot")
+st.write("""
+This chatbot uses a pre-trained RAG model to provide responses to mental health-related queries.
+Please note that this is an AI-based tool and is not a substitute for professional mental health support.
+""")
 
 # User input
-user_input = st.text_input("How are you feeling today?")
-
-# Define response generation function
-def generate_response(query):
-    # Embed the query using the sentence transformer
-    query_embedding = sentence_model.encode(query, convert_to_tensor=True).cpu().numpy()
-
-    # Search for the closest dialogue in the index
-    D, I = index.search(query_embedding, k=1)
-
-    # Retrieve the closest dialogue
-    closest_dialogue = sample_dialogues[I[0][0]]
-
-    # Generate response using RAG model
-    inputs = tokenizer(closest_dialogue, return_tensors="pt")
-    outputs = rag_model.generate(**inputs)
-    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
-    return response[0]
-
-# Generate a response when the user submits input
-if st.button("Talk to the Chatbot"):
-    if user_input:
-        with st.spinner('Generating response...'):
-            response = generate_response(user_input)
-            st.write(f"Chatbot: {response}")
+query = st.text_input("How can I help you today?")
+
+if st.button("Get Response"):
+    if query:
+        # Generate a response using the RAG model
+        inputs = tokenizer(query, return_tensors="pt")
+        outputs = rag_model.generate(**inputs)
+        response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        st.write(f"**Response:** {response[0]}")
     else:
-        st.write("Please enter something for the chatbot to respond to.")
+        st.write("Please enter a query to get a response.")