ghuman7 commited on
Commit
3e182cf
·
verified ·
1 Parent(s): 26a42f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -12
app.py CHANGED
@@ -1,22 +1,63 @@
1
  import streamlit as st
2
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
 
 
3
 
4
- # Load RAG components
 
 
 
 
 
 
 
5
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
6
  retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
7
  rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
8
 
9
- # Streamlit UI
10
- st.title("RAG-based Q&A")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- query = st.text_input("Enter your question:")
 
 
 
 
 
13
 
14
- if st.button("Generate Answer"):
15
- if query:
16
- # Process the input query and generate a response
17
- inputs = tokenizer(query, return_tensors="pt")
18
- outputs = rag_model.generate(**inputs)
19
- response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
20
- st.write(f"Answer: {response[0]}")
21
  else:
22
- st.write("Please enter a question to get an answer.")
 
1
  import streamlit as st
2
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import torch
6
 
7
+ # Title of the Streamlit app
8
+ st.title("Mental Health Chatbot")
9
+
10
+ # Load a pre-trained sentence transformer model for embedding
11
+ st.write("Loading models... Please wait.")
12
+ sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
13
+
14
+ # Load the RAG model, tokenizer, and retriever
15
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
16
  retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
17
  rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
18
 
19
+ # Sample dialogues related to mental health (replace with actual dataset for production)
20
+ sample_dialogues = [
21
+ "I'm feeling really down lately and don't know what to do.",
22
+ "I just lost my job, and I'm worried about the future.",
23
+ "I'm having trouble sleeping and feeling anxious all the time.",
24
+ "I've been feeling isolated and lonely.",
25
+ "I don't have the energy to do anything, and it's affecting my work."
26
+ ]
27
+
28
+ # Embed the sample dialogues using the sentence transformer model
29
+ embeddings = sentence_model.encode(sample_dialogues, convert_to_tensor=True)
30
+
31
+ # Build FAISS index
32
+ index = faiss.IndexFlatL2(embeddings.shape[1])
33
+ index.add(embeddings.cpu().numpy())
34
+
35
+ # User input
36
+ user_input = st.text_input("How are you feeling today?")
37
+
38
+ # Define response generation function
39
+ def generate_response(query):
40
+ # Embed the query using the sentence transformer
41
+ query_embedding = sentence_model.encode(query, convert_to_tensor=True).cpu().numpy()
42
+
43
+ # Search for the closest dialogue in the index
44
+ D, I = index.search(query_embedding, k=1)
45
+
46
+ # Retrieve the closest dialogue
47
+ closest_dialogue = sample_dialogues[I[0][0]]
48
 
49
+ # Generate response using RAG model
50
+ inputs = tokenizer(closest_dialogue, return_tensors="pt")
51
+ outputs = rag_model.generate(**inputs)
52
+ response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
53
+
54
+ return response[0]
55
 
56
+ # Generate a response when the user submits input
57
+ if st.button("Talk to the Chatbot"):
58
+ if user_input:
59
+ with st.spinner('Generating response...'):
60
+ response = generate_response(user_input)
61
+ st.write(f"Chatbot: {response}")
 
62
  else:
63
+ st.write("Please enter something for the chatbot to respond to.")