Update app.py
app.py CHANGED
@@ -15,20 +15,18 @@ from pathlib import Path
 import chromadb
 from unidecode import unidecode
 
-from transformers import AutoTokenizer
-import transformers
-import torch
-import tqdm
-import accelerate
 import re
 
 # LLM model to use
 llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
 
+# Directory where PDFs are stored
+pdf_directory = "data"
 
-# Load PDF
-def load_doc(list_file_path, chunk_size, chunk_overlap):
-    loaders = [PyPDFLoader(x) for x in list_file_path]
+# Load PDF documents from the specified directory and create doc splits
+def load_docs_from_directory(directory_path, chunk_size, chunk_overlap):
+    pdf_files = [os.path.join(directory_path, f) for f in os.listdir(directory_path) if f.endswith('.pdf')]
+    loaders = [PyPDFLoader(file) for file in pdf_files]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
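Note that the new loader calls os.listdir and os.path.join, but this hunk does not add an `import os`; it assumes os is already imported earlier in app.py. In isolation, and assuming the PyPDFLoader and RecursiveCharacterTextSplitter imports the file already uses, the new loading path amounts to roughly this sketch:

```python
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_docs_from_directory(directory_path, chunk_size, chunk_overlap):
    # Gather every top-level PDF in the directory (non-recursive scan).
    pdf_files = [os.path.join(directory_path, f)
                 for f in os.listdir(directory_path)
                 if f.endswith(".pdf")]
    # Load each PDF into per-page documents.
    pages = []
    for path in pdf_files:
        pages.extend(PyPDFLoader(path).load())
    # Split pages into overlapping chunks for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages), pdf_files
```

The scan is non-recursive, so PDFs placed in subfolders of the data directory would be ignored.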
@@ -36,8 +34,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
         chunk_size=chunk_size,
         chunk_overlap=chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
-    return doc_splits
-
+    return doc_splits, pdf_files
 
 # Create vector database
 def create_db(splits, collection_name):
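Because the return value is now a (doc_splits, pdf_files) tuple rather than a bare list, every caller has to unpack two values. A call using the UI's default slider settings would look like this sketch:

```python
# Hypothetical call with the UI defaults (chunk size 600, overlap 40).
doc_splits, pdf_files = load_docs_from_directory("data", chunk_size=600, chunk_overlap=40)
```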
@@ -51,7 +48,6 @@ def create_db(splits, collection_name):
     )
     return vectordb
 
-
 # Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
@@ -59,7 +55,6 @@ def load_db():
         embedding_function=embedding)
     return vectordb
 
-
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.5, desc="Initializing HF Hub...")
@@ -90,7 +85,6 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     progress(0.9, desc="Done!")
     return qa_chain
 
-
 # Generate collection name for vector database
 def create_collection_name(filepath):
     collection_name = Path(filepath).stem
@@ -108,26 +102,21 @@ def create_collection_name(filepath):
     print('Collection name: ', collection_name)
     return collection_name
 
-
 # Initialize database
-def initialize_database(list_file_path, chunk_size, chunk_overlap, progress=gr.Progress()):
-    collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Loading document...")
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+def initialize_database(directory_path, chunk_size, chunk_overlap, progress=gr.Progress()):
+    progress(0.1, desc="Loading documents from directory...")
+    doc_splits, pdf_files = load_docs_from_directory(directory_path, chunk_size, chunk_overlap)
+    collection_name = create_collection_name(pdf_files[0])
     progress(0.5, desc="Generating vector database...")
     vector_db = create_db(doc_splits, collection_name)
-    progress(0.9, desc="Done!")
+    progress(0.9, desc="Database initialization complete!")
     return vector_db, collection_name, "Complete!"
 
-
 def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     print("LLM model: ", llm_model)
     qa_chain = initialize_llmchain(llm_model, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
-
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
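A side effect of the reordering above: the collection is now named after whichever PDF os.listdir happens to return first, and pdf_files[0] raises an IndexError when the data folder contains no PDFs. A defensive variant (hypothetical, not part of this commit) could guard the lookup:

```python
# Hypothetical guard, not in the commit: fail early with a clear message.
if not pdf_files:
    raise FileNotFoundError(f"No PDF files found in '{directory_path}'")
collection_name = create_collection_name(pdf_files[0])
```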
@@ -135,7 +124,6 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
-
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
@@ -153,15 +141,6 @@ def conversation(qa_chain, message, history):
     return qa_chain, gr.update(
         value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
-
-def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-    return list_file_path
-
-
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
@@ -178,62 +157,53 @@ def demo():
         <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
         """)
 
-        with gr.Tab("Step 1 - Upload PDF"):
-            with gr.Row():
-                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-
-        with gr.Tab("Step 2 - Process document"):
-            with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index",
-                                  info="Choose your vector database")
-            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size",
-                                                  info="Chunk size", interactive=True)
-                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap",
-                                                     info="Chunk overlap", interactive=True)
-            with gr.Row():
-                db_progress = gr.Textbox(label="Vector database initialization", value="None")
-            with gr.Row():
-                db_btn = gr.Button("Generate vector database")
-
-        with gr.Tab("Step 3 - Initialize QA chain"):
-            with gr.Row():
-                slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature",
-                                               info="Model temperature", interactive=True)
-            with gr.Row():
-                slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens",
-                                             info="Model max tokens", interactive=True)
-            with gr.Row():
-                slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples",
-                                        info="Model top-k samples", interactive=True)
-            with gr.Row():
-                llm_progress = gr.Textbox(value="None", label="QA chain initialization")
-            with gr.Row():
-                qachain_btn = gr.Button("Initialize Question Answering chain")
-
-        with gr.Tab("Step 4 - Chatbot"):
-            chatbot = gr.Chatbot(height=300)
-            with gr.Accordion("Advanced - Document references", open=False):
-                with gr.Row():
-                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                    source1_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                    source2_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                    source3_page = gr.Number(label="Page", scale=1)
-            with gr.Row():
-                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Submit message")
-                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+        gr.Markdown("<h4>Step 1 - Process and Load Documents from 'data' Folder</h4>")
+        with gr.Row():
+            slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size",
+                                          info="Chunk size", interactive=True)
+        with gr.Row():
+            slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap",
+                                             info="Chunk overlap", interactive=True)
+        with gr.Row():
+            db_progress = gr.Textbox(label="Vector database initialization", value="None")
+        with gr.Row():
+            db_btn = gr.Button("Generate vector database")
+
+        gr.Markdown("<h4>Step 2 - Initialize QA chain</h4>")
+        with gr.Row():
+            slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature",
+                                           info="Model temperature", interactive=True)
+        with gr.Row():
+            slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens",
+                                         info="Model max tokens", interactive=True)
+        with gr.Row():
+            slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples",
+                                    info="Model top-k samples", interactive=True)
+        with gr.Row():
+            llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+        with gr.Row():
+            qachain_btn = gr.Button("Initialize Question Answering chain")
+
+        gr.Markdown("<h4>Step 3 - Chatbot</h4>")
+        chatbot = gr.Chatbot(height=300)
+        with gr.Accordion("Advanced - Document references", open=False):
+            with gr.Row():
+                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                source1_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                source2_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                source3_page = gr.Number(label="Page", scale=1)
+        with gr.Row():
+            msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
+        with gr.Row():
+            submit_btn = gr.Button("Submit message")
+            clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
 
         db_btn.click(initialize_database, \
-                     inputs=[document, slider_chunk_size, slider_chunk_overlap], \
+                     inputs=[pdf_directory, slider_chunk_size, slider_chunk_overlap], \
                      outputs=[vector_db, collection_name, db_progress])
         qachain_btn.click(initialize_LLM, \
                           inputs=[slider_temperature, slider_maxtokens, slider_topk, vector_db], \
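Caution: in the rewired db_btn.click, pdf_directory is a plain Python string, but Gradio expects each entry in inputs to be a component, so this call is likely to fail at event registration. One possible fix (hypothetical, not part of this commit) is to wrap the constant in gr.State:

```python
# Hypothetical fix: wrap the directory constant in a component so Gradio
# can pass its value to initialize_database as an input.
pdf_dir_state = gr.State(pdf_directory)
db_btn.click(initialize_database,
             inputs=[pdf_dir_state, slider_chunk_size, slider_chunk_overlap],
             outputs=[vector_db, collection_name, db_progress])
```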