JaganathC committed on
Commit a70773f · verified
1 Parent(s): 2e89d52

Upload 4 files

Files changed (4)
  1. app (3).py +343 -0
  2. requirements (1).txt +14 -0
  3. requirements-dev.txt +2 -0
  4. retrieval.py +122 -0
app (3).py ADDED
@@ -0,0 +1,343 @@
+ """
+ PDF-based chatbot with Retrieval-Augmented Generation
+ """
+
+ import os
+ import gradio as gr
+
+ from dotenv import load_dotenv
+
+ import indexing
+ import retrieval
+
+
+ # default_persist_directory = './chroma_HF/'
+ list_llm = [
+     "mistralai/Mistral-7B-Instruct-v0.3",
+     "microsoft/Phi-3.5-mini-instruct",
+     "meta-llama/Llama-3.1-8B-Instruct",
+     "meta-llama/Llama-3.2-3B-Instruct",
+     "meta-llama/Llama-3.2-1B-Instruct",
+     "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+     "HuggingFaceH4/zephyr-7b-beta",
+     "HuggingFaceH4/zephyr-7b-gemma-v0.1",
+     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+     "google/gemma-2-2b-it",
+     "google/gemma-2-9b-it",
+     "Qwen/Qwen2.5-1.5B-Instruct",
+     "Qwen/Qwen2.5-3B-Instruct",
+     "Qwen/Qwen2.5-7B-Instruct",
+ ]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
+
+ # Load environment file - HuggingFace API key
+ def retrieve_api():
+     """Retrieve HuggingFace API key"""
+     _ = load_dotenv()
+     global huggingfacehub_api_token
+     huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")
+
+
+ # Initialize database
+ def initialize_database(
+     list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()
+ ):
+     """Initialize the vector database from the uploaded PDF files"""
+
+     # Create list of documents (when valid)
+     list_file_path = [x.name for x in list_file_obj if x is not None]
+
+     # Create collection_name for vector database
+     progress(0.1, desc="Creating collection name...")
+     collection_name = indexing.create_collection_name(list_file_path[0])
+
+     progress(0.25, desc="Loading document...")
+     # Load document and create splits
+     doc_splits = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)
+
+     # Create or load vector database
+     progress(0.5, desc="Generating vector database...")
+     vector_db = indexing.create_db(doc_splits, collection_name)
+
+     return vector_db, collection_name, "Complete!"
+
+
+ # Initialize LLM
+ def initialize_llm(
+     llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
+ ):
+     """Initialize the LLM and build the QA chain"""
+
+     llm_name = list_llm[llm_option]
+     print("llm_name: ", llm_name)
+     qa_chain = retrieval.initialize_llmchain(
+         llm_name, huggingfacehub_api_token, llm_temperature, max_tokens, top_k, vector_db, progress
+     )
+     return qa_chain, "Complete!"
+
+
+ # Chatbot conversation
+ def conversation(qa_chain, message, history):
+     """Run one chatbot turn and collect the top source documents"""
+
+     qa_chain, new_history, response_sources = retrieval.invoke_qa_chain(
+         qa_chain, message, history
+     )
+
+     # Format output for the Gradio components
+     response_source1 = response_sources[0].page_content.strip()
+     response_source2 = response_sources[1].page_content.strip()
+     response_source3 = response_sources[2].page_content.strip()
+     # Langchain page numbers are zero-based
+     response_source1_page = response_sources[0].metadata["page"] + 1
+     response_source2_page = response_sources[1].metadata["page"] + 1
+     response_source3_page = response_sources[2].metadata["page"] + 1
+
+     return (
+         qa_chain,
+         gr.update(value=""),
+         new_history,
+         response_source1,
+         response_source1_page,
+         response_source2,
+         response_source2_page,
+         response_source3,
+         response_source3_page,
+     )
+
+
+ SPACE_TITLE = """
+ <center><h2>PDF-based chatbot</h2></center>
+ <h3>Ask any questions about your PDF documents</h3>
+ """
+
+ SPACE_INFO = """
+ <b>Description:</b> This AI assistant, built with Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) on your PDF documents. \
+ The user interface explicitly shows the individual steps to help you follow the RAG workflow.
+ The chatbot takes past questions into account when generating answers (via conversational memory) and cites document references for clarity.<br>
+ <br><b>Notes:</b> Updated space with more recent LLM models (Qwen 2.5, Llama 3.2, SmolLM2 series)
+ <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models below (free inference endpoints) can take a while to generate a reply.
+ """
+
+
+ # Gradio User Interface
+ def gradio_ui():
+     """Gradio User Interface"""
+
+     with gr.Blocks(theme="base") as demo:
+         vector_db = gr.State()
+         qa_chain = gr.State()
+         collection_name = gr.State()
+
+         gr.Markdown(SPACE_TITLE)
+         gr.Markdown(SPACE_INFO)
+
+         with gr.Tab("Step 1 - Upload PDF"):
+             with gr.Row():
+                 document = gr.File(
+                     height=200,
+                     file_count="multiple",
+                     file_types=[".pdf"],
+                     interactive=True,
+                     label="Upload your PDF documents (single or multiple)",
+                 )
+
+         with gr.Tab("Step 2 - Process document"):
+             with gr.Row():
+                 # Informational selector; ChromaDB is the only option.
+                 # Named db_choice so it is not shadowed by the db_btn button below.
+                 db_choice = gr.Radio(
+                     ["ChromaDB"],
+                     label="Vector database type",
+                     value="ChromaDB",
+                     type="index",
+                     info="Choose your vector database",
+                 )
+             with gr.Accordion("Advanced options - Document text splitter", open=False):
+                 with gr.Row():
+                     slider_chunk_size = gr.Slider(
+                         minimum=100,
+                         maximum=1000,
+                         value=600,
+                         step=20,
+                         label="Chunk size",
+                         info="Chunk size",
+                         interactive=True,
+                     )
+                 with gr.Row():
+                     slider_chunk_overlap = gr.Slider(
+                         minimum=10,
+                         maximum=200,
+                         value=40,
+                         step=10,
+                         label="Chunk overlap",
+                         info="Chunk overlap",
+                         interactive=True,
+                     )
+             with gr.Row():
+                 db_progress = gr.Textbox(
+                     label="Vector database initialization", value="None"
+                 )
+             with gr.Row():
+                 db_btn = gr.Button("Generate vector database")
+
+         with gr.Tab("Step 3 - Initialize QA chain"):
+             with gr.Row():
+                 llm_btn = gr.Radio(
+                     list_llm_simple,
+                     label="LLM models",
+                     value=list_llm_simple[6],
+                     type="index",
+                     info="Choose your LLM model",
+                 )
+             with gr.Accordion("Advanced options - LLM model", open=False):
+                 with gr.Row():
+                     slider_temperature = gr.Slider(
+                         minimum=0.01,
+                         maximum=1.0,
+                         value=0.7,
+                         step=0.1,
+                         label="Temperature",
+                         info="Model temperature",
+                         interactive=True,
+                     )
+                 with gr.Row():
+                     slider_maxtokens = gr.Slider(
+                         minimum=224,
+                         maximum=4096,
+                         value=1024,
+                         step=32,
+                         label="Max Tokens",
+                         info="Model max tokens",
+                         interactive=True,
+                     )
+                 with gr.Row():
+                     slider_topk = gr.Slider(
+                         minimum=1,
+                         maximum=10,
+                         value=3,
+                         step=1,
+                         label="top-k samples",
+                         info="Model top-k samples",
+                         interactive=True,
+                     )
+             with gr.Row():
+                 llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+             with gr.Row():
+                 qachain_btn = gr.Button("Initialize Question Answering chain")
+
+         with gr.Tab("Step 4 - Chatbot"):
+             chatbot = gr.Chatbot(height=300, type="tuples")
+             with gr.Accordion("Advanced - Document references", open=False):
+                 with gr.Row():
+                     doc_source1 = gr.Textbox(
+                         label="Reference 1", lines=2, container=True, scale=20
+                     )
+                     source1_page = gr.Number(label="Page", scale=1)
+                 with gr.Row():
+                     doc_source2 = gr.Textbox(
+                         label="Reference 2", lines=2, container=True, scale=20
+                     )
+                     source2_page = gr.Number(label="Page", scale=1)
+                 with gr.Row():
+                     doc_source3 = gr.Textbox(
+                         label="Reference 3", lines=2, container=True, scale=20
+                     )
+                     source3_page = gr.Number(label="Page", scale=1)
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type message (e.g. 'Can you summarize this document in one paragraph?')",
+                     container=True,
+                 )
+             with gr.Row():
+                 submit_btn = gr.Button("Submit message")
+                 clear_btn = gr.ClearButton(
+                     components=[msg, chatbot], value="Clear conversation"
+                 )
+
+         # Preprocessing events
+         db_btn.click(
+             initialize_database,
+             inputs=[document, slider_chunk_size, slider_chunk_overlap],
+             outputs=[vector_db, collection_name, db_progress],
+         )
+         qachain_btn.click(
+             initialize_llm,
+             inputs=[
+                 llm_btn,
+                 slider_temperature,
+                 slider_maxtokens,
+                 slider_topk,
+                 vector_db,
+             ],
+             outputs=[qa_chain, llm_progress],
+         ).then(
+             lambda: [None, "", 0, "", 0, "", 0],
+             inputs=None,
+             outputs=[
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+
+         # Chatbot events
+         msg.submit(
+             conversation,
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[
+                 qa_chain,
+                 msg,
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+         submit_btn.click(
+             conversation,
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[
+                 qa_chain,
+                 msg,
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+         clear_btn.click(
+             lambda: [None, "", 0, "", 0, "", 0],
+             inputs=None,
+             outputs=[
+                 chatbot,
+                 doc_source1,
+                 source1_page,
+                 doc_source2,
+                 source2_page,
+                 doc_source3,
+                 source3_page,
+             ],
+             queue=False,
+         )
+     demo.queue().launch(debug=True)
+
+
+ if __name__ == "__main__":
+     retrieve_api()
+     gradio_ui()
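
Note: app (3).py imports an indexing module that is not among the four uploaded files. A minimal sketch of what it could look like, with create_collection_name, load_doc, and create_db inferred purely from the call sites above; the PyPDF loader, recursive character splitter, and default HuggingFace embeddings are assumptions, not the author's confirmed implementation:

```python
# Hypothetical indexing.py matching the interface used by app (3).py.
import os
import re

from unidecode import unidecode  # unidecode is in the requirements
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma


def create_collection_name(file_path):
    """Derive an ASCII, Chroma-friendly collection name from the first PDF."""
    name = os.path.splitext(os.path.basename(file_path))[0]
    name = unidecode(name)  # strip accents
    name = re.sub(r"[^a-zA-Z0-9._-]", "-", name)
    return name[:50] or "pdf-collection"


def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDFs and split them into overlapping text chunks."""
    pages = []
    for file_path in list_file_path:
        pages.extend(PyPDFLoader(file_path).load())
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return splitter.split_documents(pages)


def create_db(doc_splits, collection_name):
    """Embed the chunks and store them in a Chroma collection."""
    return Chroma.from_documents(
        documents=doc_splits,
        embedding=HuggingFaceEmbeddings(),
        collection_name=collection_name,
    )
```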
requirements (1).txt ADDED
@@ -0,0 +1,14 @@
+ transformers[torch]
+ sentence-transformers
+ langchain
+ langchain-community
+ langchain-huggingface
+ langchain-chroma
+ huggingface-hub
+ tqdm
+ accelerate
+ pypdf
+ chromadb
+ unidecode
+ gradio
+ python-dotenv
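
Note: python-dotenv is listed because retrieve_api() in app (3).py loads the Hugging Face token from the environment. A quick sanity check for a local run, assuming a .env file next to the app (the placeholder token value is hypothetical):

```python
# Verify the token that retrieve_api() expects.
# Assumes a .env file containing a line like:
#   HUGGINGFACE_API_KEY=hf_xxxxxxxx   (hypothetical placeholder)
import os

from dotenv import load_dotenv

load_dotenv()
token = os.environ.get("HUGGINGFACE_API_KEY")
if token is None:
    raise SystemExit("HUGGINGFACE_API_KEY is not set; add it to .env")
print("Hugging Face token detected (first characters):", token[:6])
```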
requirements-dev.txt ADDED
@@ -0,0 +1,2 @@
+ pylint
+ black
retrieval.py ADDED
@@ -0,0 +1,122 @@
+ """
+ LLM chain retrieval
+ """
+
+ import json
+ import gradio as gr
+
+ from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain_huggingface import HuggingFaceEndpoint
+ from langchain_core.prompts import PromptTemplate
+
+
+ # Default system template for the RAG application; used as a fallback
+ # when prompt_template.json is not available
+ PROMPT_TEMPLATE = """
+ You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
+ Question: {question}
+ Context: {context}
+ Helpful Answer:
+ """
+
+
+ # Initialize langchain LLM chain
+ def initialize_llmchain(
+     llm_model,
+     huggingfacehub_api_token,
+     temperature,
+     max_tokens,
+     top_k,
+     vector_db,
+     progress=gr.Progress(),
+ ):
+     """Initialize Langchain LLM chain"""
+
+     progress(0.1, desc="Preparing LLM configuration...")
+     # HuggingFaceEndpoint calls the HF inference endpoints
+     progress(0.5, desc="Initializing HF Hub...")
+     # Use of trust_remote_code as model_kwargs
+     # Warning: langchain issue
+     # URL: https://github.com/langchain-ai/langchain/issues/6080
+
+     # if 'Llama' in llm_model:
+     #     task = "conversational"
+     # else:
+     #     task = "text-generation"
+     # print(f"Task: {task}")
+
+     llm = HuggingFaceEndpoint(
+         repo_id=llm_model,
+         task="text-generation",
+         # task="conversational",
+         provider="hf-inference",
+         temperature=temperature,
+         max_new_tokens=max_tokens,
+         top_k=top_k,
+         huggingfacehub_api_token=huggingfacehub_api_token,
+     )
+
+     progress(0.75, desc="Defining buffer memory...")
+     memory = ConversationBufferMemory(
+         memory_key="chat_history", output_key="answer", return_messages=True
+     )
+     # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
+     retriever = vector_db.as_retriever()
+
+     progress(0.8, desc="Defining retrieval chain...")
+     # Load the prompt from prompt_template.json when present; fall back to the
+     # in-module PROMPT_TEMPLATE, since the JSON file is not among the uploaded files
+     try:
+         with open("prompt_template.json", "r") as file:
+             system_prompt = json.load(file)
+         prompt_template = system_prompt["prompt"]
+     except FileNotFoundError:
+         prompt_template = PROMPT_TEMPLATE
+     rag_prompt = PromptTemplate(
+         template=prompt_template, input_variables=["context", "question"]
+     )
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         combine_docs_chain_kwargs={"prompt": rag_prompt},
+         return_source_documents=True,
+         # return_generated_question=False,
+         verbose=False,
+     )
+     progress(0.9, desc="Done!")
+
+     return qa_chain
+
+
+ def format_chat_history(message, chat_history):
+     """Format chat history for the LLM chain"""
+
+     formatted_chat_history = []
+     for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}")
+         formatted_chat_history.append(f"Assistant: {bot_message}")
+     return formatted_chat_history
+
+
+ def invoke_qa_chain(qa_chain, message, history):
+     """Invoke question-answering chain"""
+
+     formatted_chat_history = format_chat_history(message, history)
+
+     # Generate response using QA chain
+     response = qa_chain.invoke(
+         {"question": message, "chat_history": formatted_chat_history}
+     )
+
+     response_sources = response["source_documents"]
+
+     # Some models echo the prompt; keep only the text after "Helpful Answer:"
+     response_answer = response["answer"]
+     if "Helpful Answer:" in response_answer:
+         response_answer = response_answer.split("Helpful Answer:")[-1]
+
+     # Append user message and response to chat history
+     new_history = history + [(message, response_answer)]
+
+     return qa_chain, new_history, response_sources
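
Note: initialize_llmchain() prefers a prompt_template.json file and only falls back to the in-module PROMPT_TEMPLATE when that file is missing; the JSON file itself is not part of this commit. A minimal sketch that generates a compatible file, assuming the format is a JSON object with a single "prompt" key (inferred from system_prompt["prompt"] above):

```python
# Generate a prompt_template.json compatible with retrieval.py.
# The {"prompt": ...} shape is an assumption based on how the file is read.
import json

from retrieval import PROMPT_TEMPLATE

with open("prompt_template.json", "w") as file:
    json.dump({"prompt": PROMPT_TEMPLATE}, file, indent=2)
```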