Fecalisboa committed
Commit 2ab1cce · verified · 1 Parent(s): 81255d9

Update app.py

Files changed (1)
  1. app.py +11 -218
app.py CHANGED
@@ -1,6 +1,5 @@
 import gradio as gr
 import os
-
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
@@ -10,6 +9,7 @@ from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
+from huggingface_hub import login
 
 from pathlib import Path
 import chromadb
@@ -22,28 +22,10 @@ import tqdm
 import accelerate
 import re
 
-# LlamaParse import
-from llama_parse import LlamaParse
-import asyncio
-from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
-from llama_index.core.base.response.schema import PydanticResponse
-from llama_index.core.bridge.pydantic import BaseModel, Field, ValidationError
-from llama_index.core.callbacks.base import CallbackManager
-from llama_index.core.llms.llm import LLM
-from llama_index.core.node_parser.interface import NodeParser
-from llama_index.core.schema import BaseNode, Document, IndexNode, TextNode
-from llama_index.core.utils import get_tqdm_iterable
-
 from io import StringIO
 from typing import Any, Callable, List, Optional
 
 import pandas as pd
-from llama_index.core.node_parser.relational.base_element import (
-    # BaseElementNodeParser,
-    Element,
-)
-from llama_index.core.schema import BaseNode, TextNode
-
 
 # Get the token from the environment variable
 api_token = os.getenv("HF_TOKEN")
@@ -58,7 +40,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size = chunk_size, chunk_overlap = chunk_overlap)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
@@ -87,164 +69,13 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
 
     progress(0.5, desc="Initializing HF Hub...")
 
-    if llm_model == "mistralai/Mistral-7B-Instruct-v0.2":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            huggingfacehub_api_token = api_token,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-    else:
-        llm = HuggingFaceEndpoint(
-            huggingfacehub_api_token = api_token,
-            repo_id=llm_model,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-
-    progress(0.75, desc="Defining buffer memory...")
-    memory = ConversationBufferMemory(
-        memory_key="chat_history",
-        output_key='answer',
-        return_messages=True
+    llm = HuggingFaceEndpoint(
+        repo_id=llm_model,
+        huggingfacehub_api_token=api_token,
+        temperature=temperature,
+        max_new_tokens=max_tokens,
+        top_k=top_k,
     )
-    retriever=vector_db.as_retriever()
-    progress(0.8, desc="Defining retrieval chain...")
-    qa_chain = ConversationalRetrievalChain.from_llm(
-        llm,
-        retriever=retriever,
-        chain_type="stuff",
-        memory=memory,
-        return_source_documents=True,
-        verbose=False,
-    )
-    progress(0.9, desc="Done!")
-    return qa_chain
-
-# Generate collection name for vector database
-def create_collection_name(filepath):
-    collection_name = Path(filepath).stem
-    collection_name = collection_name.replace(" ","-")
-    collection_name = unidecode(collection_name)
-    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-    collection_name = collection_name[:50]
-    if len(collection_name) < 3:
-        collection_name = collection_name + 'xyz'
-    if not collection_name[0].isalnum():
-        collection_name = 'A' + collection_name[1:]
-    if not collection_name[-1].isalnum():
-        collection_name = collection_name[:-1] + 'Z'
-    print('Filepath: ', filepath)
-    print('Collection name: ', collection_name)
-    return collection_name
-
-# Initialize database
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    progress(0.1, desc="Creating collection name...")
-    collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Loading document...")
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    progress(0.5, desc="Generating vector database...")
-    vector_db = create_db(doc_splits, collection_name)
-    progress(0.9, desc="Done!")
-    return vector_db, collection_name, "Complete!"
-
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    llm_name = list_llm[llm_option]
-    print("llm_name: ",llm_name)
-    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-    return qa_chain, "Complete!"
-
-def format_chat_history(message, chat_history):
-    formatted_chat_history = []
-    for user_message, bot_message in chat_history:
-        formatted_chat_history.append(f"User: {user_message}")
-        formatted_chat_history.append(f"Assistant: {bot_message}")
-    return formatted_chat_history
-
-
-def conversation(qa_chain, message, history):
-    formatted_chat_history = format_chat_history(message, history)
-
-    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-    response_answer = response["answer"]
-    if response_answer.find("Helpful Answer:") != -1:
-        response_answer = response_answer.split("Helpful Answer:")[-1]
-    response_sources = response["source_documents"]
-    response_source1 = response_sources[0].page_content.strip()
-    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
-    response_source1_page = response_sources[0].metadata["page"] + 1
-    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
-
-    new_history = history + [(message, response_answer)]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
-def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-    return list_file_path
-
-list_llm = ["mistralai/Miceli", "mistralai/Mistral-7B-Instruct-v0.3"]
-list_llm_simple = [os.path.basename(llm) for llm in list_llm]
-
-# Load PDF document and create doc splits
-def load_doc(list_file_path, chunk_size, chunk_overlap):
-    loaders = [PyPDFLoader(x) for x in list_file_path]
-    pages = []
-    for loader in loaders:
-        pages.extend(loader.load())
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size = chunk_size, chunk_overlap = chunk_overlap)
-    doc_splits = text_splitter.split_documents(pages)
-    return doc_splits
-
-# Create vector database
-def create_db(splits, collection_name):
-    embedding = HuggingFaceEmbeddings()
-    new_client = chromadb.EphemeralClient()
-    vectordb = Chroma.from_documents(
-        documents=splits,
-        embedding=embedding,
-        client=new_client,
-        collection_name=collection_name,
-    )
-    return vectordb
-
-# Load vector database
-def load_db():
-    embedding = HuggingFaceEmbeddings()
-    vectordb = Chroma(
-        embedding_function=embedding)
-    return vectordb
-
-# Initialize langchain LLM chain
-def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Initializing HF tokenizer...")
-
-    progress(0.5, desc="Initializing HF Hub...")
-
-    if llm_model == "mistralai/Mistral-7B-Instruct-v0.2":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            huggingfacehub_api_token = api_token,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
-    else:
-        llm = HuggingFaceEndpoint(
-            huggingfacehub_api_token = api_token,
-            repo_id=llm_model,
-            temperature = temperature,
-            max_new_tokens = max_tokens,
-            top_k = top_k,
-        )
 
     progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
@@ -252,7 +83,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
         output_key='answer',
         return_messages=True
     )
-    retriever=vector_db.as_retriever()
+    retriever = vector_db.as_retriever()
     progress(0.8, desc="Defining retrieval chain...")
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
@@ -294,37 +125,9 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     progress(0.9, desc="Done!")
     return vector_db, collection_name, "Complete!"
 
-# Initialize LlamaIndex parsing
-def initialize_llama_index(file_obj):
-    documents = LlamaParse(result_type="markdown", api_key=api_token).load_data(file_obj[0].name)
-    node_parser = MarkdownElementNodeParser(llm=None, num_workers=8)
-    nodes = node_parser.get_nodes_from_documents(documents)
-    base_nodes, objects = node_parser.get_nodes_and_objects(nodes)
-
-    # Using SimpleVectorStore to create a vector index
-    vector_store = SimpleVectorStore()
-    for node in base_nodes + objects:
-        vector_store.add(node)
-
-    # Creating a retriever from the vector index
-    index_ret = VectorIndexRetriever(vector_store=vector_store, top_k=15)
-
-    # Configuring the query engine
-    reranker = FlagEmbeddingReranker(
-        top_n=5,
-        model="BAAI/bge-reranker-large"
-    )
-    recursive_query_engine = RetrieverQueryEngine(
-        retriever=index_ret,
-        node_postprocessors=[reranker],
-        verbose=False
-    )
-
-    return recursive_query_engine, "LlamaIndex parsing complete"
-
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     llm_name = list_llm[llm_option]
-    print("llm_name: ",llm_name)
+    print("llm_name: ", llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
@@ -364,7 +167,6 @@ def demo():
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()
-        llama_index_engine = gr.State()
 
         gr.Markdown(
             """<center><h2>PDF-based chatbot</center></h2>
@@ -407,13 +209,7 @@ def demo():
             with gr.Row():
                 qachain_btn = gr.Button("Initialize Question Answering chain")
 
-        with gr.Tab("Step 4 - LlamaIndex parsing"):
-            with gr.Row():
-                llama_index_btn = gr.Button("Parse with LlamaIndex")
-            with gr.Row():
-                llama_index_progress = gr.Textbox(label="LlamaIndex parsing status", value="None")
-
-        with gr.Tab("Step 5 - Chatbot"):
+        with gr.Tab("Step 4 - Chatbot"):
             chatbot = gr.Chatbot(height=300)
             with gr.Accordion("Advanced - Document references", open=False):
                 with gr.Row():
@@ -441,9 +237,6 @@ def demo():
             inputs=None,
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
             queue=False)
-        llama_index_btn.click(initialize_llama_index,
-            inputs=[document],
-            outputs=[llama_index_engine, llama_index_progress])
 
         # Chatbot events
         msg.submit(conversation,