Commit 173b629
Parent(s): 131c5e4
First commit

- app.py +211 -0
- requirements.txt +2 -0
app.py
ADDED
@@ -0,0 +1,211 @@
import gradio as gr
from dotenv import load_dotenv
load_dotenv()

import warnings
warnings.filterwarnings("ignore")

import os, requests, git, shutil
from collections import defaultdict
from itertools import chain

from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFaceEndpoint
from langchain.storage import InMemoryStore
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers import ParentDocumentRetriever, BM25Retriever
from langchain.retrievers.document_compressors import LLMChainExtractor, LLMChainFilter, EmbeddingsFilter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.prompts import PromptTemplate


HF_READ_API_KEY = os.environ["HF_READ_API_KEY"]

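# Helpers: format retrieved documents for display, and load the PDF from the
# path/URL supplied in the UI.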
def get_text(docs):
    return ['Result ' + str(i+1) + '\n' + d.page_content + '\n' for i, d in enumerate(docs)]

def load_pdf(path):
    loader = PyMuPDFLoader(path)
    docs = loader.load()

    return docs, 'PDF loaded successfully'


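# Multi Query retrieval: ask the LLM for three rephrasings of the question,
# retrieve documents for the original and each rephrasing, then de-duplicate
# the combined results.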
def multi_query_retrieval(query, llm, retriever):
    DEFAULT_QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI assistant. Generate 3 different versions of the given question to retrieve relevant docs.
Provide these alternative questions separated by newlines.
Original question: {question}""",
    )
    mq_llm_chain = LLMChain(llm=llm, prompt=DEFAULT_QUERY_PROMPT)

    generated_queries = mq_llm_chain.invoke(query)['text'].split("\n")
    all_queries = [query] + generated_queries

    all_retrieved_docs = []
    for q in all_queries:
        retrieved_docs = retriever.get_relevant_documents(q)
        all_retrieved_docs.extend(retrieved_docs)

    unique_retrieved_docs = [doc for i, doc in enumerate(all_retrieved_docs) if doc not in all_retrieved_docs[:i]]

    return get_text(unique_retrieved_docs)

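# Contextual compression: retrieve with the base retriever, then shrink the
# results with one of three compressors -- LLMChainExtractor (pull out the
# relevant passages), LLMChainFilter (keep or drop whole documents), or
# EmbeddingsFilter (drop documents below a 0.5 similarity threshold).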
def compressed_retrieval(query, llm, retriever, extractor_type='chain', embedding_model=None):
    retrieved_docs = retriever.get_relevant_documents(query)
    if extractor_type == 'chain':
        extractor = LLMChainExtractor.from_llm(llm)
    elif extractor_type == 'filter':
        extractor = LLMChainFilter.from_llm(llm)
    elif extractor_type == 'embeddings':
        if embedding_model is None:
            raise ValueError("Embeddings model must be provided for embeddings extractor.")
        extractor = EmbeddingsFilter(embeddings=embedding_model, similarity_threshold=0.5)
    else:
        raise ValueError("Invalid extractor_type. Options are 'chain', 'filter', or 'embeddings'.")
    compressed_docs = extractor.compress_documents(retrieved_docs, query)
    return get_text(compressed_docs)

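# Ensemble retrieval via Reciprocal Rank Fusion: each retriever contributes
# weight / (rank + c) per document, scores are summed across retrievers, and
# the de-duplicated documents are returned in descending RRF-score order.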
def unique_by_key(iterable, key_func):
    seen = set()
    for element in iterable:
        key = key_func(element)
        if key not in seen:
            seen.add(key)
            yield element

def ensemble_retrieval(query, retrievers_list, c=60):
    retrieved_docs_by_retriever = [retriever.get_relevant_documents(query) for retriever in retrievers_list]
    weights = [1 / len(retrievers_list)] * len(retrievers_list)
    rrf_score = defaultdict(float)
    for doc_list, weight in zip(retrieved_docs_by_retriever, weights):
        for rank, doc in enumerate(doc_list, start=1):
            rrf_score[doc.page_content] += weight / (rank + c)

    all_docs = chain.from_iterable(retrieved_docs_by_retriever)
    sorted_docs = sorted(
        unique_by_key(all_docs, lambda doc: doc.page_content),
        key=lambda doc: rrf_score[doc.page_content],
        reverse=True
    )
    return get_text(sorted_docs)

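# Long Context Reorder: interleave the ranked results so the most relevant
# documents sit at the beginning and end of the context and the least relevant
# land in the middle (to counter the "lost in the middle" effect).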
def long_context_reorder_retrieval(query, retriever):
    retrieved_docs = retriever.get_relevant_documents(query)
    retrieved_docs.reverse()
    reordered_results = []
    for i, doc in enumerate(retrieved_docs):
        if i % 2 == 1:
            reordered_results.append(doc)
        else:
            reordered_results.insert(0, doc)
    return get_text(reordered_results)

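# Main pipeline: split the PDF into chunks, embed them into a Chroma store,
# build the selected retriever, and answer the query with a Hugging Face
# inference endpoint using the retrieved context.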
def process_query(docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p):

    chunking_parameters = {'chunk_size': chunk_size, 'chunk_overlap': chunk_overlap}
    inference_model_params = {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p}

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunking_parameters['chunk_size'], chunk_overlap=chunking_parameters['chunk_overlap'])

    texts = text_splitter.split_documents(docs)

    hf = HuggingFaceEmbeddings(model_name=embedding_model)
    vector_db_from_docs = Chroma.from_documents(texts, hf)
    simple_retriever = vector_db_from_docs.as_retriever(search_kwargs={"k": 5})

    llm_model = HuggingFaceEndpoint(repo_id=inference_model,
                                    max_new_tokens=inference_model_params['max_new_tokens'],
                                    temperature=inference_model_params['temperature'],
                                    top_p=inference_model_params['top_p'],
                                    huggingfacehub_api_token=HF_READ_API_KEY)

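    # Dispatch on the retrieval method chosen in the UI; every branch produces
    # a list of formatted result strings.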
    if retrieval_method == "Simple":
        retrieved_docs = simple_retriever.get_relevant_documents(query)
        result = get_text(retrieved_docs)
    elif retrieval_method == "Parent & Child":
        parent_text_splitter = child_text_splitter = text_splitter
        vector_db = Chroma(collection_name="parent_child", embedding_function=hf)
        store = InMemoryStore()
        pr_retriever = ParentDocumentRetriever(
            vectorstore=vector_db,
            docstore=store,
            child_splitter=child_text_splitter,
            parent_splitter=parent_text_splitter,
        )
        pr_retriever.add_documents(docs)
        retrieved_docs = pr_retriever.get_relevant_documents(query)
        result = get_text(retrieved_docs)
    elif retrieval_method == "Multi Query":
        result = multi_query_retrieval(query, llm_model, simple_retriever)
    elif retrieval_method == "Contextual Compression (chain extraction)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='chain')
    elif retrieval_method == "Contextual Compression (query filter)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='filter')
    elif retrieval_method == "Contextual Compression (embeddings filter)":
        result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='embeddings', embedding_model=hf)
    elif retrieval_method == "Ensemble":
        bm25_retriever = BM25Retriever.from_documents(docs)
        all_retrievers = [simple_retriever, bm25_retriever]
        result = ensemble_retrieval(query, all_retrievers)
    elif retrieval_method == "Long Context Reorder":
        result = long_context_reorder_retrieval(query, simple_retriever)
    else:
        raise ValueError(f"Unknown retrieval method: {retrieval_method}")

    prompt_template = PromptTemplate.from_template(
        "Answer the query {query} with the following context:\n {context}. If you cannot use the context to answer the query, say 'I cannot answer the query with the provided context.'"
    )

    answer = llm_model.invoke(prompt_template.format(query=query, context=result))

    return "\n".join(result), answer.strip()

embedding_model_list = ['sentence-transformers/all-MiniLM-L6-v2', 'BAAI/bge-small-en-v1.5', 'BAAI/bge-large-en-v1.5']
inference_model_list = ['google/gemma-2b-it', 'google/gemma-7b-it', 'microsoft/phi-2', 'mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2']
retrieval_method_list = ["Simple", "Parent & Child", "Multi Query",
                         "Contextual Compression (chain extraction)", "Contextual Compression (query filter)",
                         "Contextual Compression (embeddings filter)", "Ensemble", "Long Context Reorder"]

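# Gradio UI: the left column collects the PDF URL, query, and model/retrieval
# settings; the right column shows the generated answer and the raw retrieval
# results.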
with gr.Blocks() as demo:
    gr.Markdown("## Compare Retrieval Methods for PDFs")
    with gr.Row():
        with gr.Column():
            pdf_url = gr.Textbox(label="Enter URL to PDF", value="https://www.berkshirehathaway.com/letters/2023ltr.pdf")
            load_button = gr.Button("Load and process PDF")
            status = gr.Textbox(label="Status")
            docs = gr.State()
            load_button.click(load_pdf, inputs=[pdf_url], outputs=[docs, status])

            query = gr.Textbox(label="Enter your query", value="What does Warren Buffet think about Coca Cola?")
            with gr.Row():
                embedding_model = gr.Dropdown(embedding_model_list, label="Select Embedding Model", value=embedding_model_list[0])
                inference_model = gr.Dropdown(inference_model_list, label="Select Inference Model", value=inference_model_list[0])
                retrieval_method = gr.Dropdown(retrieval_method_list, label="Select Retrieval Method", value=retrieval_method_list[0])

            with gr.Row():
                chunk_size = gr.Number(label="Chunk Size", value=1000)
                chunk_overlap = gr.Number(label="Chunk Overlap", value=200)

            with gr.Row():
                max_new_tokens = gr.Number(label="Max New Tokens", value=100)
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Top P", value=0.9)

            search_button = gr.Button("Retrieval")
        with gr.Column():
            answer = gr.Textbox(label="Answer")
            retrieval_output = gr.Textbox(label="Retrieval Results")

    search_button.click(process_query, inputs=[docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p], outputs=[retrieval_output, answer])

if __name__ == "__main__":
    demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,2 @@
langchain
pypdf