arjunanand13 committed
Commit 5ead45e
1 Parent(s): 82add23

Create app.py

Files changed (1)
  1. app.py +305 -0
app.py ADDED
@@ -0,0 +1,305 @@
+ import gradio as gr
+ import os
+ from typing import List, Dict
+ import ragas
+ from ragas.metrics import (
+     context_relevancy,
+     faithfulness,
+     answer_relevancy,
+     context_recall
+ )
+ from datasets import Dataset, load_dataset
+ from langchain.text_splitter import (
+     RecursiveCharacterTextSplitter,
+     CharacterTextSplitter
+ )
+ # Semantic chunking lives in langchain_experimental, not langchain.text_splitter
+ from langchain_experimental.text_splitter import SemanticChunker
+ from langchain_community.vectorstores import FAISS, Chroma, Qdrant
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain.memory import ConversationBufferMemory
+ import torch
+
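+ # Rough dependency set for this Space (an assumption; the commit pins nothing):
+ #   pip install gradio langchain langchain-community langchain-experimental \
+ #       ragas datasets faiss-cpu chromadb qdrant-client pypdf sentence-transformers torch
+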
+ # Constants
+ list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+ api_token = os.getenv("HF_TOKEN")
+
+ # Text splitting strategies
+ def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
+     splitters = {
+         "recursive": RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap
+         ),
+         "fixed": CharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap
+         ),
+         # SemanticChunker splits at embedding-similarity breakpoints, so it takes
+         # an embeddings object rather than chunk_size/chunk_overlap
+         "semantic": SemanticChunker(HuggingFaceEmbeddings())
+     }
+     return splitters.get(strategy)
+
+ # Load and split PDF documents
+ def load_doc(list_file_path: List[str], splitting_strategy: str = "recursive"):
+     loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
+     for loader in loaders:
+         pages.extend(loader.load())
+
+     text_splitter = get_text_splitter(splitting_strategy)
+     doc_splits = text_splitter.split_documents(pages)
+     return doc_splits
+
+ # Vector database creation functions
+ def create_faiss_db(splits, embeddings):
+     return FAISS.from_documents(splits, embeddings)
+
+ def create_chroma_db(splits, embeddings):
+     return Chroma.from_documents(splits, embeddings)
+
+ def create_qdrant_db(splits, embeddings):
+     return Qdrant.from_documents(
+         splits,
+         embeddings,
+         location=":memory:",
+         collection_name="pdf_docs"
+     )
+
+ def create_db(splits, db_choice: str = "faiss"):
+     embeddings = HuggingFaceEmbeddings()
+     db_creators = {
+         "faiss": create_faiss_db,
+         "chroma": create_chroma_db,
+         "qdrant": create_qdrant_db
+     }
+     return db_creators[db_choice](splits, embeddings)
+
+ # Evaluation functions
+ def load_evaluation_dataset():
+     # Load the FiQA example dataset published by the RAGAS project
+     # (a config name such as "ragas_eval" may be required depending on the dataset version)
+     dataset = load_dataset("explodinggradients/fiqa", split="test")
+     return dataset
+
+ def evaluate_rag_pipeline(qa_chain, dataset):
+     # Sample a few examples for evaluation
+     eval_samples = dataset.select(range(5))
+
+     # Collect question/answer/context records; ragas scores a datasets.Dataset
+     # in one pass rather than exposing per-sample .score() calls
+     records = {
+         "question": [],
+         "answer": [],
+         "contexts": [],
+         "ground_truth": []
+     }
+
+     for sample in eval_samples:
+         question = sample["question"]
+
+         # Get response from the chain
+         response = qa_chain.invoke({
+             "question": question,
+             "chat_history": []
+         })
+
+         records["question"].append(question)
+         records["answer"].append(response["answer"])
+         records["contexts"].append(
+             [doc.page_content for doc in response["source_documents"]]
+         )
+         records["ground_truth"].append(sample["answer"])
+
+     # Evaluate using RAGAS metrics; column names follow the ragas 0.1.x schema
+     result = ragas.evaluate(
+         Dataset.from_dict(records),
+         metrics=[context_relevancy, faithfulness, answer_relevancy, context_recall]
+     )
+
+     # The result behaves like a mapping of metric name -> aggregate score
+     return dict(result)
+
+ # Initialize langchain LLM chain
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     # Both supported models use the same endpoint configuration
+     llm = HuggingFaceEndpoint(
+         repo_id=llm_model,
+         huggingfacehub_api_token=api_token,
+         temperature=temperature,
+         max_new_tokens=max_tokens,
+         top_k=top_k,
+     )
+
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         output_key='answer',
+         return_messages=True
+     )
+
+     retriever = vector_db.as_retriever()
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         return_source_documents=True,
+         verbose=False,
+     )
+     return qa_chain
+
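+ # NOTE: `initialize_LLM` is wired to qachain_btn below but was not defined in
+ # this commit; this is a minimal sketch that maps the Radio's index output back
+ # to a model repo id and reports a status string for the progress textbox.
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     llm_name = list_llm[llm_option]
+     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+     return qa_chain, "QA chain initialized. Chatbot is ready!"
+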
+ # Initialize database with chunking strategy and vector DB choice
+ def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
+     list_file_path = [x.name for x in list_file_obj if x is not None]
+     doc_splits = load_doc(list_file_path, splitting_strategy)
+     vector_db = create_db(doc_splits, db_choice)
+     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"
+
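+ # NOTE: `conversation` and `format_chat_history` are wired to the chat events in
+ # demo() below but were not defined in this commit; this minimal sketch
+ # reconstructs the usual template behaviour: send the question through the
+ # chain, then surface the answer plus up to three source passages with pages.
+ def format_chat_history(message, chat_history):
+     formatted_chat_history = []
+     for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}")
+         formatted_chat_history.append(f"Assistant: {bot_message}")
+     return formatted_chat_history
+
+ def conversation(qa_chain, message, history):
+     formatted_chat_history = format_chat_history(message, history)
+     response = qa_chain.invoke({
+         "question": message,
+         "chat_history": formatted_chat_history
+     })
+     response_answer = response["answer"]
+     if "Helpful Answer:" in response_answer:
+         response_answer = response_answer.split("Helpful Answer:")[-1]
+
+     # Pad to three references so the fixed set of Gradio outputs is always filled
+     sources = response["source_documents"][:3]
+     texts = [doc.page_content.strip() for doc in sources] + [""] * (3 - len(sources))
+     pages = [doc.metadata.get("page", 0) + 1 for doc in sources] + [0] * (3 - len(sources))
+
+     new_history = history + [(message, response_answer)]
+     return (qa_chain, gr.update(value=""), new_history,
+             texts[0], pages[0], texts[1], pages[1], texts[2], pages[2])
+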
+ def demo():
+     with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
+         vector_db = gr.State()
+         qa_chain = gr.State()
+
+         gr.HTML("<center><h1>Enhanced RAG PDF Chatbot</h1></center>")
+         gr.Markdown("""<b>Query your PDF documents with advanced RAG capabilities!</b>""")
+
+         with gr.Row():
+             with gr.Column(scale=86):
+                 gr.Markdown("<b>Step 1 - Configure and Initialize RAG Pipeline</b>")
+                 with gr.Row():
+                     document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
+
+                 with gr.Row():
+                     splitting_strategy = gr.Radio(
+                         ["recursive", "fixed", "semantic"],
+                         label="Text Splitting Strategy",
+                         value="recursive"
+                     )
+                     db_choice = gr.Radio(
+                         ["faiss", "chroma", "qdrant"],
+                         label="Vector Database",
+                         value="faiss"
+                     )
+
+                 with gr.Row():
+                     db_btn = gr.Button("Create vector database")
+                     evaluate_btn = gr.Button("Evaluate RAG Pipeline")
+
+                 with gr.Row():
+                     db_progress = gr.Textbox(value="Not initialized", show_label=False)
+                     evaluation_results = gr.JSON(label="Evaluation Results")
+
+                 gr.Markdown("<b>Select Large Language Model (LLM) and input parameters</b>")
+                 with gr.Row():
+                     llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
+
+                 with gr.Row():
+                     with gr.Accordion("LLM input parameters", open=False):
+                         slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature")
+                         slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens")
+                         slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k")
+
+                 with gr.Row():
+                     qachain_btn = gr.Button("Initialize Question Answering Chatbot")
+                     llm_progress = gr.Textbox(value="Not initialized", show_label=False)
+
+             with gr.Column(scale=200):
+                 gr.Markdown("<b>Step 2 - Chat with your Document</b>")
+                 chatbot = gr.Chatbot(height=505)
+
+                 with gr.Accordion("Relevant context from the source document", open=False):
+                     with gr.Row():
+                         doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                         source1_page = gr.Number(label="Page", scale=1)
+                     with gr.Row():
+                         doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                         source2_page = gr.Number(label="Page", scale=1)
+                     with gr.Row():
+                         doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                         source3_page = gr.Number(label="Page", scale=1)
+
+                 with gr.Row():
+                     msg = gr.Textbox(placeholder="Ask a question", container=True)
+                 with gr.Row():
+                     submit_btn = gr.Button("Submit")
+                     clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
+
+         # Event handlers
+         db_btn.click(
+             initialize_database,
+             inputs=[document, splitting_strategy, db_choice],
+             outputs=[vector_db, db_progress]
+         )
+
+         evaluate_btn.click(
+             lambda chain: evaluate_rag_pipeline(chain, load_evaluation_dataset()) if chain else None,
+             inputs=[qa_chain],
+             outputs=[evaluation_results]
+         )
+
+         qachain_btn.click(
+             initialize_LLM,
+             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
+             outputs=[qa_chain, llm_progress]
+         ).then(
+             lambda: [None, "", 0, "", 0, "", 0],
+             inputs=None,
+             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+             queue=False
+         )
+
+         # Chatbot event handlers
+         msg.submit(
+             conversation,
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+             queue=False
+         )
+
+         submit_btn.click(
+             conversation,
+             inputs=[qa_chain, msg, chatbot],
+             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+             queue=False
+         )
+
+         clear_btn.click(
+             lambda: [None, "", 0, "", 0, "", 0],
+             inputs=None,
+             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+             queue=False
+         )
+
+     demo.queue().launch(debug=True)
+
+ if __name__ == "__main__":
+     demo()