SiraH committed (verified)
Commit b31680c · 1 parent: d2f09f5

Update app.py

Files changed (1)
  1. app.py +170 -31
app.py CHANGED
@@ -1,34 +1,38 @@
- import os
- import streamlit as st
- import re
- from tempfile import NamedTemporaryFile
- import time
+ import os
+ import streamlit as st
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ import re
  import pathlib
- #from PyPDF2 import PdfReader
+ from tempfile import NamedTemporaryFile
 
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
  from langchain_community.llms import LlamaCpp
- from langchain.prompts import PromptTemplate
- from langchain.chains import LLMChain
+ from langchain import PromptTemplate, LLMChain
  from langchain.callbacks.manager import CallbackManager
  from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
  from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain.chains import RetrievalQA
  from langchain_community.vectorstores import FAISS
+ from PyPDF2 import PdfReader
+ import os
+ import time
  from langchain.chains.question_answering import load_qa_chain
  from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
+
  from langchain_community.document_loaders import TextLoader
  from langchain_community.document_loaders import PyPDFLoader
+ # from langchain.document_loaders import PyPDFLoader
+ # from langchain.document_loaders import Docx2txtLoader
+ # from langchain.document_loaders.image import UnstructuredImageLoader
+ # from langchain.document_loaders import UnstructuredHTMLLoader
+ # from langchain.document_loaders import UnstructuredPowerPointLoader
+ # from langchain.document_loaders import TextLoader
  from langchain.memory import ConversationBufferWindowMemory
+
  from langchain.memory import ConversationBufferMemory
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory.chat_message_histories.streamlit import StreamlitChatMessageHistory
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.llms import HuggingFaceHub
-
- SECRET_TOKEN = os.getenv("HF_TOKEN")
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = SECRET_TOKEN
-
 
  # sidebar contents
  with st.sidebar:
@@ -36,10 +40,127 @@ with st.sidebar:
      st.markdown('''
      ## About
      Detail this application:
-     - LLM model: Llama2-7b-4bit
+     - LLM model: llama2-7b-chat-4bit
      - Hardware resource : Huggingface space 8 vCPU 32 GB
      ''')
-
+
+ class UploadDoc:
+     def __init__(self, path_data):
+         self.path_data = path_data
+
+     def prepare_filetype(self):
+         extension_lists = {
+             ".docx": [],
+             ".pdf": [],
+             ".html": [],
+             ".png": [],
+             ".pptx": [],
+             ".txt": [],
+         }
+
+         path_list = []
+         for path, subdirs, files in os.walk(self.path_data):
+             for name in files:
+                 path_list.append(os.path.join(path, name))
+                 #print(os.path.join(path, name))
+
+         # Loop through the path_list and categorize files
+         for filename in path_list:
+             file_extension = pathlib.Path(filename).suffix
+             #print("File Extension:", file_extension)
+
+             if file_extension in extension_lists:
+                 extension_lists[file_extension].append(filename)
+         return extension_lists
+
+     def upload_docx(self, extension_lists):
+         #word
+         data_docxs = []
+         for doc in extension_lists[".docx"]:
+             loader = Docx2txtLoader(doc)
+             data = loader.load()
+             data_docxs.extend(data)
+         return data_docxs
+
+     def upload_pdf(self, extension_lists):
+         #pdf
+         data_pdf = []
+         for doc in extension_lists[".pdf"]:
+             loader = PyPDFLoader(doc)
+             data = loader.load_and_split()
+             data_pdf.extend(data)
+         return data_pdf
+
+     def upload_html(self, extension_lists):
+         #html
+         data_html = []
+         for doc in extension_lists[".html"]:
+             loader = UnstructuredHTMLLoader(doc)
+             data = loader.load()
+             data_html.extend(data)
+         return data_html
+
+     def upload_png_ocr(self, extension_lists):
+         #png ocr
+         data_png = []
+         for doc in extension_lists[".png"]:
+             loader = UnstructuredImageLoader(doc)
+             data = loader.load()
+             data_png.extend(data)
+         return data_png
+
+     def upload_pptx(self, extension_lists):
+         #power point
+         data_pptx = []
+         for doc in extension_lists[".pptx"]:
+             loader = UnstructuredPowerPointLoader(doc)
+             data = loader.load()
+             data_pptx.extend(data)
+         return data_pptx
+
+     def upload_txt(self, extension_lists):
+         #txt
+         data_txt = []
+         for doc in extension_lists[".txt"]:
+             loader = TextLoader(doc)
+             data = loader.load()
+             data_txt.extend(data)
+         return data_txt
+
+     def count_files(self, extension_lists):
+         file_extension_counts = {}
+         # Count the quantity of each item
+         for ext, file_list in extension_lists.items():
+             file_extension_counts[ext] = len(file_list)
+         return print(f"number of file:{file_extension_counts}")
+         # Print the counts
+         # for ext, count in file_extension_counts.items():
+         #     return print(f"{ext}: {count} file")
+
+     def create_document(self, dataframe=True):
+         documents = []
+         extension_lists = self.prepare_filetype()
+         self.count_files(extension_lists)
+
+         upload_functions = {
+             ".docx": self.upload_docx,
+             ".pdf": self.upload_pdf,
+             ".html": self.upload_html,
+             ".png": self.upload_png_ocr,
+             ".pptx": self.upload_pptx,
+             ".txt": self.upload_txt,
+         }
+
+         for extension, upload_function in upload_functions.items():
+             if len(extension_lists[extension]) > 0:
+                 if extension == ".xlsx" or extension == ".csv":
+                     data = upload_function(extension_lists, dataframe)
+                 else:
+                     data = upload_function(extension_lists)
+                 documents.extend(data)
+
+         return documents
+
  def split_docs(documents,chunk_size=1000):
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=200)
      sp_docs = text_splitter.split_documents(documents)
@@ -47,7 +168,7 @@ def split_docs(documents,chunk_size=1000):
 
  @st.cache_resource
  def load_llama2_llamaCpp():
-     core_model_name = "llama-2-7b-chat.Q4_0.gguf"
+     core_model_name = "phi-2.Q4_K_M.gguf"
      #n_gpu_layers = 32
      n_batch = 512
      callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
@@ -56,16 +177,18 @@ def load_llama2_llamaCpp():
          #n_gpu_layers=n_gpu_layers,
          n_batch=n_batch,
          callback_manager=callback_manager,
-         verbose=True,n_ctx = 4096, temperature = 0.1, max_tokens = 128
+         verbose=True,n_ctx = 4096, temperature = 0.1, max_tokens = 512
          )
      return llm
 
  def set_custom_prompt():
      custom_prompt_template = """ Use the following pieces of information from context to answer the user's question.
      If you don't know the answer, don't try to make up an answer.
+
      Context : {context}
      Question : {question}
-     Please answer the questions in a concise and straightforward manner.
+
+     Only returns the helpful answer below and nothing else.
      Helpful answer:
      """
      prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context',
@@ -80,8 +203,6 @@ def load_embeddings():
                                         model_kwargs = {'device': 'cpu'})
      return embeddings
 
-
-
  def main():
      data = []
      sp_docs_list = []
@@ -90,15 +211,16 @@ def main():
      if "messages" not in st.session_state:
          st.session_state.messages = []
 
-     # repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
-     # llm = HuggingFaceHub(
-     #     repo_id=repo_id, model_kwargs={"temperature": 0.1, "max_length": 128})
-
-
      llm = load_llama2_llamaCpp()
      qa_prompt = set_custom_prompt()
      embeddings = load_embeddings()
+     #memory = ConversationBufferWindowMemory(k = 0, return_messages=True, input_key= 'question', output_key='answer', memory_key="chat_history")
+     #memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+     #doc_chain = load_qa_chain(llm, chain_type="stuff", prompt = qa_prompt)
+     #question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
+     #embeddings = load_embeddings()
 
+
      uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
      if uploaded_file is not None :
          with NamedTemporaryFile(dir='PDF', suffix='.pdf', delete=False) as f:
@@ -117,7 +239,8 @@ def main():
          sp_docs = split_docs(documents = data)
          st.write(f"This document have {len(sp_docs)} chunks")
          sp_docs_list.extend(sp_docs)
-     try:
+
+     try :
          db = FAISS.from_documents(sp_docs_list, embeddings)
          memory = ConversationBufferMemory(memory_key="chat_history",
                                            return_messages=True,
@@ -129,7 +252,19 @@ def main():
              retriever = db.as_retriever(search_kwargs = {'k':3}),
              return_source_documents = True,
              memory = memory,
-             chain_type_kwargs = {"prompt":qa_prompt})
+             chain_type_kwargs = {"prompt":qa_prompt})
+
+
+         # qa_chain = ConversationalRetrievalChain(
+         #     retriever =db.as_retriever(search_kwargs={'k':2}),
+         #     question_generator=question_generator,
+         #     #condense_question_prompt=CONDENSE_QUESTION_PROMPT,
+         #     combine_docs_chain=doc_chain,
+         #     return_source_documents=True,
+         #     memory = memory,
+         #     #get_chat_history=lambda h :h
+         #     )
+
          for message in st.session_state.messages:
              with st.chat_message(message["role"]):
                  st.markdown(message["content"])
@@ -145,6 +280,9 @@ def main():
              start = time.time()
 
              response = qa_chain({'query': query})
+
+             #url_list = set([i.metadata['page'] for i in response['source_documents']])
+             #print(f"condensed quesion : {question_generator.run({'chat_history': response['chat_history'], 'question' : query})}")
 
              with st.chat_message("assistant"):
                  st.markdown(response['result'])
@@ -158,6 +296,7 @@ def main():
 
              with st.expander("See the related documents"):
                  for count, url in enumerate(response['source_documents']):
+                     #url_reg = regex_source(url)
                      st.write(str(count+1)+":", url)
 
              clear_button = st.button("Start new convo")
@@ -165,9 +304,9 @@ def main():
                  st.session_state.messages = []
                  qa_chain.memory.chat_memory.clear()
 
-     except:
+     except :
          st.write("Plaese upload your pdf file.")
-
-
+
+
  if __name__ == '__main__':
      main()