ibibek committed on
Commit 9e7b1cc · 1 Parent(s): 568d334

Upload 4 files

Files changed (4)
  1. app.py +98 -0
  2. requirements.txt +8 -0
  3. streaming.py +11 -0
  4. utils.py +55 -0
app.py ADDED
@@ -0,0 +1,98 @@
+ import os
+ import utils
+ import streamlit as st
+ from streaming import StreamHandler
+
+ from langchain.chat_models import ChatOpenAI
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.memory import ConversationBufferMemory
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.vectorstores import DocArrayInMemorySearch
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import OpenAIEmbeddings
+
+ st.header('Chatbot for AEO')
+ st.write('Please upload the necessary files about AEO in the sidebar and ask questions in the chat.')
+
+
+
+
+ class CustomDataChatbot:
+
+     def __init__(self):
+         self.openai_api_key = utils.configure_openai_api_key()
+         self.openai_model = "gpt-3.5-turbo"
+
+     def save_file(self, file):
+         folder = 'tmp'
+         if not os.path.exists(folder):
+             os.makedirs(folder)
+
+         file_path = f'./{folder}/{file.name}'
+         with open(file_path, 'wb') as f:
+             f.write(file.getvalue())
+         return file_path
+
+     def setup_qa_chain(self, uploaded_files):
+         with st.spinner('Analyzing documents..'):
+             # Load documents
+             docs = []
+             for file in uploaded_files:
+                 file_path = self.save_file(file)
+                 loader = PyPDFLoader(file_path)
+                 docs.extend(loader.load())
+
+             # Split documents
+             text_splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=1500,
+                 chunk_overlap=200
+             )
+             splits = text_splitter.split_documents(docs)
+
+             # Create embeddings and store in vectordb
+
+             embeddings = OpenAIEmbeddings(openai_api_key=self.openai_api_key)
+             vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
+
+             # Define retriever
+             retriever = vectordb.as_retriever(
+                 search_type='mmr',
+                 search_kwargs={'k': 2, 'fetch_k': 4}
+             )
+
+             # Setup memory for contextual conversation
+             memory = ConversationBufferMemory(
+                 memory_key='chat_history',
+                 return_messages=True
+             )
+
+             # Setup LLM and QA chain
+             llm = ChatOpenAI(model_name=self.openai_model, temperature=0, streaming=True)
+             qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory, verbose=True)
+             return qa_chain
+
+     @utils.enable_chat_history
+     def main(self):
+
+         # User Inputs
+         uploaded_files = st.sidebar.file_uploader(label='Upload PDF files', type=['pdf'], accept_multiple_files=True)
+         if not uploaded_files:
+             st.error("Please upload PDF documents to continue!")
+             st.stop()
+
+         user_query = st.chat_input(placeholder="Ask me anything!")
+
+         if uploaded_files and user_query:
+             qa_chain = self.setup_qa_chain(uploaded_files)
+
+             utils.display_msg(user_query, 'user')
+
+             with st.chat_message("assistant"):
+                 st_cb = StreamHandler(st.empty())
+                 response = qa_chain.run(user_query, callbacks=[st_cb])
+                 st.session_state.messages.append({"role": "assistant", "content": response})
+
+ if __name__ == "__main__":
+     obj = CustomDataChatbot()
+     obj.main()
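
With the dependencies pinned in requirements.txt installed, the app is presumably launched with streamlit run app.py; the sidebar then asks for an OpenAI API key (via utils.configure_openai_api_key) and for one or more PDF uploads before the chat input becomes usable.
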
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ langchain==0.0.228
+ openai==0.27.8
+ streamlit==1.24.0
+ duckduckgo-search==3.8.3
+ pypdf==3.12.0
+ sentence-transformers==2.2.2
+ docarray==0.32.1
+ tiktoken
streaming.py ADDED
@@ -0,0 +1,11 @@
+ from langchain.callbacks.base import BaseCallbackHandler
+
+ class StreamHandler(BaseCallbackHandler):
+
+     def __init__(self, container, initial_text=""):
+         self.container = container
+         self.text = initial_text
+
+     def on_llm_new_token(self, token: str, **kwargs):
+         self.text += token
+         self.container.markdown(self.text)
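
StreamHandler is a LangChain callback handler: each on_llm_new_token call appends the new token to the accumulated text and re-renders it into the Streamlit container, which is how app.py streams the answer (it passes the handler to qa_chain.run via callbacks). A minimal sketch, not part of this commit, that drives the handler by hand instead of from an LLM:

import streamlit as st
from streaming import StreamHandler

placeholder = st.empty()             # container the handler writes into
handler = StreamHandler(placeholder)

# Simulate the token-by-token callbacks a streaming chat model would emit.
for token in ["Hello", ", ", "world", "!"]:
    handler.on_llm_new_token(token)  # appends the token and re-renders the full text
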
utils.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import random
+ import streamlit as st
+
+ # decorator
+ def enable_chat_history(func):
+     if os.environ.get("OPENAI_API_KEY"):
+
+         # to clear chat history after switching chatbot
+         current_page = func.__qualname__
+         if "current_page" not in st.session_state:
+             st.session_state["current_page"] = current_page
+         if st.session_state["current_page"] != current_page:
+             try:
+                 st.cache_resource.clear()
+                 del st.session_state["current_page"]
+                 del st.session_state["messages"]
+             except KeyError:
+                 pass
+
+         # to show chat history on ui
+         if "messages" not in st.session_state:
+             st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
+         for msg in st.session_state["messages"]:
+             st.chat_message(msg["role"]).write(msg["content"])
+
+     def execute(*args, **kwargs):
+         func(*args, **kwargs)
+     return execute
+
+ def display_msg(msg, author):
+     """Display a chat message on the UI and record it in the session history
+
+     Args:
+         msg (str): message to display
+         author (str): author of the message (user/assistant)
+     """
+     st.session_state.messages.append({"role": author, "content": msg})
+     st.chat_message(author).write(msg)
+
+ def configure_openai_api_key():
+     openai_api_key = st.sidebar.text_input(
+         label="OpenAI API Key",
+         type="password",
+         value=st.session_state['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in st.session_state else '',
+         placeholder="sk-..."
+     )
+     if openai_api_key:
+         st.session_state['OPENAI_API_KEY'] = openai_api_key
+         os.environ['OPENAI_API_KEY'] = openai_api_key
+     else:
+         st.error("Please add your OpenAI API key to continue.")
+         st.info("Obtain your key from this link: https://platform.openai.com/account/api-keys")
+         st.stop()
+     return openai_api_key
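
A minimal sketch, not part of this commit, of how the helpers in utils.py could be combined in a standalone Streamlit page (the echo reply is made up for illustration):

import streamlit as st
import utils

utils.configure_openai_api_key()      # sidebar key box; st.stop() until a key is provided

@utils.enable_chat_history            # seeds the greeting and re-draws prior messages on each rerun
def main():
    user_text = st.chat_input(placeholder="Say something")
    if user_text:
        utils.display_msg(user_text, 'user')
        utils.display_msg(f"You said: {user_text}", 'assistant')

main()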