Sakil committed on
Commit
81a587f
·
1 Parent(s): b26e8d5

code added

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
3
+ from langchain import PromptTemplate
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.llms import CTransformers
7
+ from langchain.chains import RetrievalQA
8
+ import chainlit as cl
9
+
10
# Path to the persisted FAISS vector index (built elsewhere; loaded in qa_bot()).
DB_FAISS_PATH = 'vectorstore/db_faiss'

# Prompt for the RetrievalQA chain; {context} and {question} are the two
# input variables filled in at query time (see set_custom_prompt()).
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
21
+
22
def set_custom_prompt():
    """Build the PromptTemplate used by the QA retrieval chain.

    Wraps the module-level ``custom_prompt_template`` with the two
    variables ('context', 'question') the chain substitutes at query time.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
29
+
30
# Retrieval QA Chain
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a RetrievalQA chain over *db* using *llm* and *prompt*.

    The retriever fetches the top-2 matching chunks, the 'stuff' chain
    type packs them into the prompt, and source documents are returned
    alongside the answer.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
    return chain
39
+
40
# Loading the model
def load_llm():
    """Load the locally downloaded llama-2-7b-chat GGML model via CTransformers.

    Expects the model file to be present in the working directory.
    """
    return CTransformers(
        model="llama-2-7b-chat.ggmlv3.q8_0.bin",
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5,
    )
50
+
51
# QA Model Function
@st.cache_resource
def qa_bot():
    """Build and return the retrieval-QA chain (cached across reruns).

    Loads the sentence-transformer embeddings, the persisted FAISS index
    at DB_FAISS_PATH, and the local llama LLM, then wires them together
    via retrieval_qa_chain().

    Fix: decorated with @st.cache_resource so Streamlit does not reload
    the embeddings model, FAISS index, and LLM from disk on every script
    rerun (i.e. on every widget interaction in main()).

    Returns:
        A RetrievalQA chain callable as chain({'query': ...}).
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    # NOTE(review): newer langchain releases require
    # allow_dangerous_deserialization=True here — confirm installed version.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    return retrieval_qa_chain(llm, qa_prompt, db)
61
+
62
def main():
    """Streamlit entry point: render the chat UI and answer questions.

    Builds the QA chain, reads a question from a text input, and on
    button press runs the chain and displays the answer plus sources.

    Fix: guards against running the full retrieval chain on an empty or
    whitespace-only query when "Ask" is clicked with no input.
    """
    st.title("AI Chatbot with Streamlit")

    qa_result = qa_bot()

    user_input = st.text_input("Enter your question:")

    if st.button("Ask"):
        # Don't invoke the (expensive) chain on an empty query.
        if not user_input.strip():
            st.warning("Please enter a question first.")
            return
        response = qa_result({'query': user_input})
        answer = response["result"]
        sources = response["source_documents"]

        st.write("Answer:", answer)
        if sources:
            st.write("Sources:", sources)
        else:
            st.write("No sources found")

if __name__ == "__main__":
    main()