playgrdstar committed
Commit 173b629 · 1 Parent(s): 131c5e4

First commit

Files changed (2):
  1. app.py +211 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,211 @@
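# A Gradio app for comparing LangChain retrieval methods (simple similarity
# search, parent/child, multi-query, contextual compression, BM25/vector
# ensemble, and long-context reorder) over a PDF loaded from a URL.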
import gradio as gr
from dotenv import load_dotenv
load_dotenv()

import warnings
warnings.filterwarnings("ignore")

import os
from collections import defaultdict
from itertools import chain

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFaceEndpoint
from langchain.storage import InMemoryStore
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers import ParentDocumentRetriever, BM25Retriever
from langchain.retrievers.document_compressors import LLMChainExtractor, LLMChainFilter, EmbeddingsFilter
from langchain_community.document_loaders import PyMuPDFLoader

HF_READ_API_KEY = os.environ["HF_READ_API_KEY"]
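
# Helpers: format retrieved documents for display, and load a PDF from a local
# path or URL (PyMuPDFLoader downloads remote files to a temporary file).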
def get_text(docs):
    return ['Result ' + str(i+1) + '\n' + d.page_content + '\n' for i, d in enumerate(docs)]

def load_pdf(path):
    loader = PyMuPDFLoader(path)
    docs = loader.load()

    return docs, 'PDF loaded successfully'

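# Multi-query retrieval: ask the LLM to rephrase the question, retrieve for the
# original plus each rephrasing, then deduplicate the pooled results.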
def multi_query_retrieval(query, llm, retriever):
    DEFAULT_QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI assistant. Generate 3 different versions of the given question to retrieve relevant docs.
Provide these alternative questions separated by newlines.
Original question: {question}""",
    )
    mq_llm_chain = LLMChain(llm=llm, prompt=DEFAULT_QUERY_PROMPT)

    # Drop empty lines from the LLM output before retrieving.
    generated_queries = [q for q in mq_llm_chain.invoke(query)['text'].split("\n") if q.strip()]
    all_queries = [query] + generated_queries

    all_retrieved_docs = []
    for q in all_queries:
        retrieved_docs = retriever.get_relevant_documents(q)
        all_retrieved_docs.extend(retrieved_docs)

    unique_retrieved_docs = [doc for i, doc in enumerate(all_retrieved_docs) if doc not in all_retrieved_docs[:i]]

    return get_text(unique_retrieved_docs)

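# Contextual compression: post-process retrieved documents with an LLM chain
# extractor, an LLM chain filter, or an embeddings-similarity filter.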
def compressed_retrieval(query, llm, retriever, extractor_type='chain', embedding_model=None):
    retrieved_docs = retriever.get_relevant_documents(query)
    if extractor_type == 'chain':
        extractor = LLMChainExtractor.from_llm(llm)
    elif extractor_type == 'filter':
        extractor = LLMChainFilter.from_llm(llm)
    elif extractor_type == 'embeddings':
        if embedding_model is None:
            raise ValueError("Embeddings model must be provided for embeddings extractor.")
        extractor = EmbeddingsFilter(embeddings=embedding_model, similarity_threshold=0.5)
    else:
        raise ValueError("Invalid extractor_type. Options are 'chain', 'filter', or 'embeddings'.")
    compressed_docs = extractor.compress_documents(retrieved_docs, query)
    return get_text(compressed_docs)

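# Yield elements whose key has not been seen yet, preserving order.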
def unique_by_key(iterable, key_func):
    seen = set()
    for element in iterable:
        key = key_func(element)
        if key not in seen:
            seen.add(key)
            yield element

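# Ensemble retrieval via Reciprocal Rank Fusion: each retriever contributes
# weight / (rank + c) per document, and documents are re-ranked by total score.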
def ensemble_retrieval(query, retrievers_list, c=60):
    retrieved_docs_by_retriever = [retriever.get_relevant_documents(query) for retriever in retrievers_list]
    weights = [1 / len(retrievers_list)] * len(retrievers_list)
    rrf_score = defaultdict(float)
    for doc_list, weight in zip(retrieved_docs_by_retriever, weights):
        for rank, doc in enumerate(doc_list, start=1):
            rrf_score[doc.page_content] += weight / (rank + c)

    all_docs = chain.from_iterable(retrieved_docs_by_retriever)
    sorted_docs = sorted(
        unique_by_key(all_docs, lambda doc: doc.page_content),
        key=lambda doc: rrf_score[doc.page_content],
        reverse=True
    )
    return get_text(sorted_docs)

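# Long-context reorder: place the most relevant documents at the beginning and
# end of the list and the least relevant in the middle, mitigating the
# "lost in the middle" effect in long prompts.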
def long_context_reorder_retrieval(query, retriever):
    retrieved_docs = retriever.get_relevant_documents(query)
    retrieved_docs.reverse()
    reordered_results = []
    for i, doc in enumerate(retrieved_docs):
        if i % 2 == 1:
            reordered_results.append(doc)
        else:
            reordered_results.insert(0, doc)
    return get_text(reordered_results)

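# End-to-end pipeline: split the PDF into chunks, embed them into Chroma,
# retrieve with the selected method, then answer the query with the retrieved
# context via the selected Hugging Face inference model.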
def process_query(docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p):

    # Gradio Number components return floats, so cast to int where required.
    chunking_parameters = {'chunk_size': int(chunk_size), 'chunk_overlap': int(chunk_overlap)}
    inference_model_params = {'max_new_tokens': int(max_new_tokens), 'temperature': temperature, 'top_p': top_p}

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunking_parameters['chunk_size'], chunk_overlap=chunking_parameters['chunk_overlap'])

    texts = text_splitter.split_documents(docs)

    hf = HuggingFaceEmbeddings(model_name=embedding_model)
    vector_db_from_docs = Chroma.from_documents(texts, hf)
    simple_retriever = vector_db_from_docs.as_retriever(search_kwargs={"k": 5})

    llm_model = HuggingFaceEndpoint(repo_id=inference_model,
                                    max_new_tokens=inference_model_params['max_new_tokens'],
                                    temperature=inference_model_params['temperature'],
                                    top_p=inference_model_params['top_p'],
                                    huggingfacehub_api_token=HF_READ_API_KEY)

    if retrieval_method == "Simple":
        retrieved_docs = simple_retriever.get_relevant_documents(query)
        result = get_text(retrieved_docs)
    elif retrieval_method == "Parent & Child":
        parent_text_splitter = child_text_splitter = text_splitter
        vector_db = Chroma(collection_name="parent_child", embedding_function=hf)
        store = InMemoryStore()
        pr_retriever = ParentDocumentRetriever(
            vectorstore=vector_db,
            docstore=store,
            child_splitter=child_text_splitter,
            parent_splitter=parent_text_splitter,
        )
        pr_retriever.add_documents(docs)
        retrieved_docs = pr_retriever.get_relevant_documents(query)
        result = get_text(retrieved_docs)
    elif retrieval_method == "Multi Query":
        result = multi_query_retrieval(query, llm_model, simple_retriever)
    elif retrieval_method == "Contextual Compression (chain extraction)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='chain')
    elif retrieval_method == "Contextual Compression (query filter)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='filter')
    elif retrieval_method == "Contextual Compression (embeddings filter)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='embeddings', embedding_model=hf)
    elif retrieval_method == "Ensemble":
        bm25_retriever = BM25Retriever.from_documents(docs)
        all_retrievers = [simple_retriever, bm25_retriever]
        result = ensemble_retrieval(query, all_retrievers)
    elif retrieval_method == "Long Context Reorder":
        result = long_context_reorder_retrieval(query, simple_retriever)
    else:
        raise ValueError(f"Unknown retrieval method: {retrieval_method}")

    prompt_template = PromptTemplate.from_template(
        "Answer the query {query} with the following context:\n {context}. If you cannot use the context to answer the query, say 'I cannot answer the query with the provided context.'"
    )

    # result is a list of strings; join it so the prompt receives plain text
    # rather than a Python list repr.
    answer = llm_model.invoke(prompt_template.format(query=query, context="\n".join(result)))

    return "\n".join(result), answer.strip()

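# Model and method choices exposed in the UI.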
embedding_model_list = ['sentence-transformers/all-MiniLM-L6-v2', 'BAAI/bge-small-en-v1.5', 'BAAI/bge-large-en-v1.5']
inference_model_list = ['google/gemma-2b-it', 'google/gemma-7b-it', 'microsoft/phi-2', 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2']
retrieval_method_list = ["Simple", "Parent & Child", "Multi Query",
                         "Contextual Compression (chain extraction)", "Contextual Compression (query filter)",
                         "Contextual Compression (embeddings filter)", "Ensemble", "Long Context Reorder"]

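# Gradio UI: inputs (PDF URL, query, models, chunking and generation settings)
# on the left, answer and raw retrieval results on the right.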
with gr.Blocks() as demo:
    gr.Markdown("## Compare Retrieval Methods for PDFs")
    with gr.Row():
        with gr.Column():
            pdf_url = gr.Textbox(label="Enter URL to PDF", value="https://www.berkshirehathaway.com/letters/2023ltr.pdf")
            load_button = gr.Button("Load and process PDF")
            status = gr.Textbox(label="Status")
            docs = gr.State()
            load_button.click(load_pdf, inputs=[pdf_url], outputs=[docs, status])

            query = gr.Textbox(label="Enter your query", value="What does Warren Buffett think about Coca Cola?")
            with gr.Row():
                embedding_model = gr.Dropdown(embedding_model_list, label="Select Embedding Model", value=embedding_model_list[0])
                inference_model = gr.Dropdown(inference_model_list, label="Select Inference Model", value=inference_model_list[0])
                retrieval_method = gr.Dropdown(retrieval_method_list, label="Select Retrieval Method", value=retrieval_method_list[0])

            with gr.Row():
                chunk_size = gr.Number(label="Chunk Size", value=1000)
                chunk_overlap = gr.Number(label="Chunk Overlap", value=200)

            with gr.Row():
                max_new_tokens = gr.Number(label="Max New Tokens", value=100)
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Top P", value=0.9)

            search_button = gr.Button("Retrieve")
        with gr.Column():
            answer = gr.Textbox(label="Answer")
            retrieval_output = gr.Textbox(label="Retrieval Results")

    search_button.click(process_query, inputs=[docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p], outputs=[retrieval_output, answer])

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
langchain
pypdf
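
Note: requirements.txt lists only langchain and pypdf. Judging from the imports in app.py, running this Space likely also needs gradio, python-dotenv, langchain-community, pymupdf, chromadb, sentence-transformers, and rank_bm25, plus an HF_READ_API_KEY set in the environment.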