File size: 10,154 Bytes
173b629 c2a45f4 173b629 8ac44b3 173b629 0908684 173b629 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import gradio as gr
# from dotenv import load_dotenv
# load_dotenv()
import warnings
warnings.filterwarnings("ignore")
import os, requests, shutil
from collections import defaultdict
from itertools import chain
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFaceEndpoint
from langchain.storage import InMemoryStore
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers import ParentDocumentRetriever, BM25Retriever
from langchain.retrievers.document_compressors import LLMChainExtractor, LLMChainFilter, EmbeddingsFilter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.prompts import PromptTemplate
# Hugging Face API token read from the environment at import time; if the
# variable is not set this line raises KeyError. The value is handed to
# HuggingFaceEndpoint in process_query.
HF_READ_API_KEY = os.environ["HF_READ_API_KEY"]
def get_text(docs):
    """Format documents as numbered, human-readable result strings.

    Args:
        docs: Iterable of objects exposing a string ``page_content`` attribute.

    Returns:
        A list with one string per document, shaped as
        ``"Result <n>\\n<page_content>\\n"`` where ``n`` starts at 1.
    """
    formatted = []
    for position, document in enumerate(docs, start=1):
        header = 'Result ' + str(position) + '\n'
        formatted.append(header + document.page_content + '\n')
    return formatted
def load_pdf(path):
    """Load a PDF with ``PyMuPDFLoader`` and report success.

    Args:
        path: Location of the PDF, passed unchanged to ``PyMuPDFLoader``.
            The UI supplies a URL here; presumably the loader accepts both
            local paths and URLs (confirm against the loader documentation).

    Returns:
        A ``(documents, status_message)`` tuple, where ``documents`` is
        whatever ``PyMuPDFLoader.load()`` returns.
    """
    success_message = 'PDF loaded successfully'
    documents = PyMuPDFLoader(path).load()
    return documents, success_message
def multi_query_retrieval(query, llm, retriever):
    """Retrieve documents for ``query`` and for LLM-generated rephrasings of it.

    The LLM is prompted for three alternative phrasings of the question. The
    original query and every non-blank generated line are sent to
    ``retriever``; the results are merged, keeping only the first occurrence
    of each document.

    Args:
        query: The user's question.
        llm: Language model handed to ``LLMChain`` to generate the rephrasings.
        retriever: Object exposing ``get_relevant_documents(query)``.

    Returns:
        The list of formatted result strings produced by ``get_text``.
    """
    DEFAULT_QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI assistant. Generate 3 different versions of the given question to retrieve relevant docs.
Provide these alternative questions separated by newlines.
Original question: {question}""",
    )
    mq_llm_chain = LLMChain(llm=llm, prompt=DEFAULT_QUERY_PROMPT)
    raw_output = mq_llm_chain.invoke(query)['text']

    # Model output commonly has leading/trailing blank lines and padding;
    # drop them so empty strings are never sent to the retriever.
    generated_queries = [line.strip() for line in raw_output.split("\n") if line.strip()]

    # dict.fromkeys removes repeated queries while preserving their order,
    # which avoids issuing the same retrieval twice.
    all_queries = list(dict.fromkeys([query] + generated_queries))

    unique_retrieved_docs = []
    for q in all_queries:
        for doc in retriever.get_relevant_documents(q):
            # Compare by equality (as the original did) rather than hashing,
            # so documents do not need to be hashable.
            if doc not in unique_retrieved_docs:
                unique_retrieved_docs.append(doc)
    return get_text(unique_retrieved_docs)
def compressed_retrieval(query, llm, retriever, extractor_type='chain', embedding_model=None):
    """Retrieve documents and post-process them with a document compressor.

    Args:
        query: The user's question.
        llm: Language model used by the ``'chain'`` and ``'filter'`` compressors.
        retriever: Object exposing ``get_relevant_documents(query)``.
        extractor_type: One of ``'chain'`` (``LLMChainExtractor``),
            ``'filter'`` (``LLMChainFilter``) or ``'embeddings'``
            (``EmbeddingsFilter`` with a 0.5 similarity threshold).
        embedding_model: Embeddings object; required for ``'embeddings'``.

    Returns:
        The list of formatted result strings produced by ``get_text``.

    Raises:
        ValueError: If ``extractor_type`` is not recognised, or if
            ``'embeddings'`` is selected without an ``embedding_model``.
    """
    # Retrieval happens before validation, exactly as in the original flow.
    candidates = retriever.get_relevant_documents(query)

    if extractor_type not in ('chain', 'filter', 'embeddings'):
        raise ValueError("Invalid extractor_type. Options are 'chain', 'filter', or 'embeddings'.")

    if extractor_type == 'embeddings':
        if embedding_model is None:
            raise ValueError("Embeddings model must be provided for embeddings extractor.")
        compressor = EmbeddingsFilter(embeddings=embedding_model, similarity_threshold=0.5)
    elif extractor_type == 'filter':
        compressor = LLMChainFilter.from_llm(llm)
    else:
        compressor = LLMChainExtractor.from_llm(llm)

    return get_text(compressor.compress_documents(candidates, query))
def unique_by_key(iterable, key_func):
    """Lazily yield the first element seen for each distinct key.

    Args:
        iterable: Elements to filter, consumed in order.
        key_func: Callable mapping an element to a hashable key.

    Yields:
        Elements of ``iterable`` whose key has not been produced before,
        in their original order.
    """
    emitted_keys = set()
    for item in iterable:
        marker = key_func(item)
        if marker in emitted_keys:
            continue
        emitted_keys.add(marker)
        yield item
def ensemble_retrieval(query, retrievers_list, c=60):
    """Fuse the results of several retrievers with reciprocal rank fusion.

    Every retriever gets the same weight ``1 / len(retrievers_list)``. A
    document at 1-based rank ``r`` in one retriever's results contributes
    ``weight / (r + c)`` to the score of its ``page_content``. Documents are
    de-duplicated by ``page_content`` and returned in descending score order.

    Args:
        query: The user's question.
        retrievers_list: Sequence of objects exposing
            ``get_relevant_documents(query)``.
        c: Rank-smoothing constant added to each rank in the denominator.

    Returns:
        The list of formatted result strings produced by ``get_text``;
        an empty list when ``retrievers_list`` is empty.
    """
    # Without this guard the weight computation divides by zero.
    # An empty list is what get_text([]) would produce.
    if not retrievers_list:
        return []

    retrieved_docs_by_retriever = [
        retriever.get_relevant_documents(query) for retriever in retrievers_list
    ]
    weight = 1 / len(retrievers_list)

    rrf_score = defaultdict(float)
    for doc_list in retrieved_docs_by_retriever:
        for rank, doc in enumerate(doc_list, start=1):
            rrf_score[doc.page_content] += weight / (rank + c)

    all_docs = chain.from_iterable(retrieved_docs_by_retriever)
    sorted_docs = sorted(
        unique_by_key(all_docs, lambda doc: doc.page_content),
        key=lambda doc: rrf_score[doc.page_content],
        reverse=True,
    )
    return get_text(sorted_docs)
def long_context_reorder_retrieval(query, retriever):
    """Retrieve documents and reorder them toward the ends of the list.

    The retrieved list is reversed in place; items at even positions of the
    reversed list are then stacked at the front (latest first) and items at
    odd positions follow in order. Assuming the retriever returns the most
    relevant document first, that document ends up first and the runner-up
    ends up last.

    Args:
        query: The user's question.
        retriever: Object exposing ``get_relevant_documents(query)`` and
            returning a list.

    Returns:
        The list of formatted result strings produced by ``get_text``.
    """
    ranked = retriever.get_relevant_documents(query)
    ranked.reverse()
    # Even-indexed items were each pushed to the front, so they appear in
    # reverse order; odd-indexed items were appended, so they keep their order.
    front_half = ranked[0::2][::-1]
    back_half = ranked[1::2]
    return get_text(front_half + back_half)
def process_query(docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p):
    """Run one retrieval strategy over the loaded documents and answer the query.

    The documents are split into chunks, embedded into a temporary Chroma
    collection, searched with the selected strategy, and the retrieved text
    is given to the inference model as context for answering ``query``.

    Args:
        docs: Documents returned by ``load_pdf``; must be non-empty.
        query: The user's question.
        embedding_model: Model name passed to ``HuggingFaceEmbeddings``.
        inference_model: Repo id passed to ``HuggingFaceEndpoint``.
        retrieval_method: One of the strategy labels handled below
            (e.g. ``"Simple"``, ``"Ensemble"``).
        chunk_size: Chunk size for the text splitter (coerced to ``int``).
        chunk_overlap: Chunk overlap for the text splitter (coerced to ``int``).
        max_new_tokens: Generation length limit (coerced to ``int``).
        temperature: Sampling temperature for the inference model.
        top_p: Nucleus-sampling value for the inference model.

    Returns:
        A ``(retrieval_text, answer)`` tuple: the retrieved results joined
        with newlines, and the stripped model answer.

    Raises:
        ValueError: If no documents are loaded or ``retrieval_method`` is
            not recognised.
    """
    import uuid  # function-scope import: only needed to name throwaway collections

    # The UI passes None until a PDF has been loaded; fail with a clear message
    # instead of an opaque error from the text splitter.
    if not docs:
        raise ValueError("No PDF loaded. Load a PDF before running retrieval.")

    # Numeric UI widgets may deliver floats; these three settings are integers.
    chunk_size = int(chunk_size)
    chunk_overlap = int(chunk_overlap)
    max_new_tokens = int(max_new_tokens)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(docs)
    hf = HuggingFaceEmbeddings(model_name=embedding_model)
    llm_model = HuggingFaceEndpoint(repo_id=inference_model,
                                    max_new_tokens=max_new_tokens,
                                    temperature=temperature,
                                    top_p=top_p,
                                    huggingfacehub_api_token=HF_READ_API_KEY)

    # Each call gets its own collection names. With a fixed name, chunks stored
    # by earlier requests (possibly embedded with a different model) can leak
    # into later ones when the in-memory Chroma client is shared in-process.
    run_id = uuid.uuid4().hex
    vector_stores = []
    try:
        vector_db_from_docs = Chroma.from_documents(texts, hf, collection_name=f"simple_{run_id}")
        vector_stores.append(vector_db_from_docs)
        simple_retriever = vector_db_from_docs.as_retriever(search_kwargs={"k": 5})

        if retrieval_method == "Simple":
            retrieved_docs = simple_retriever.get_relevant_documents(query)
            result = get_text(retrieved_docs)
        elif retrieval_method == "Parent & Child":
            # NOTE(review): parent and child splitters are the same object, so
            # child chunks are not smaller than parent chunks - confirm intent.
            parent_text_splitter = child_text_splitter = text_splitter
            vector_db = Chroma(collection_name=f"parent_child_{run_id}", embedding_function=hf)
            vector_stores.append(vector_db)
            store = InMemoryStore()
            pr_retriever = ParentDocumentRetriever(
                vectorstore=vector_db,
                docstore=store,
                child_splitter=child_text_splitter,
                parent_splitter=parent_text_splitter,
            )
            pr_retriever.add_documents(docs)
            retrieved_docs = pr_retriever.get_relevant_documents(query)
            result = get_text(retrieved_docs)
        elif retrieval_method == "Multi Query":
            result = multi_query_retrieval(query, llm_model, simple_retriever)
        elif retrieval_method == "Contextual Compression (chain extraction)":
            result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='chain')
        elif retrieval_method == "Contextual Compression (query filter)":
            result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='filter')
        elif retrieval_method == "Contextual Compression (embeddings filter)":
            result = compressed_retrieval(query, llm_model, simple_retriever, extractor_type='embeddings', embedding_model=hf)
        elif retrieval_method == "Ensemble":
            # Index the same chunks the vector retriever uses. Rank fusion keys
            # on page_content, so whole pages would never match vector hits.
            bm25_retriever = BM25Retriever.from_documents(texts)
            all_retrievers = [simple_retriever, bm25_retriever]
            result = ensemble_retrieval(query, all_retrievers)
        elif retrieval_method == "Long Context Reorder":
            result = long_context_reorder_retrieval(query, simple_retriever)
        else:
            raise ValueError(f"Unknown retrieval method: {retrieval_method}")
    finally:
        # Drop the per-request collections so they do not pile up in memory.
        for vector_store in vector_stores:
            vector_store.delete_collection()

    # Join once and reuse: formatting the raw list into the prompt would show
    # the model a Python list repr (brackets, quotes, escaped newlines).
    context = "\n".join(result)
    prompt_template = PromptTemplate.from_template(
        "Answer the query {query} with the following context:\n {context}. If you cannot use the context to answer the query, say 'I cannot answer the query with the provided context.'"
    )
    answer = llm_model.invoke(prompt_template.format(query=query, context=context))
    return context, answer.strip()
# Choices offered in the UI dropdowns. Order matters: the UI picks its default
# selections from these lists by index.
embedding_model_list = [
    'sentence-transformers/all-MiniLM-L6-v2',
    'BAAI/bge-small-en-v1.5',
    'BAAI/bge-large-en-v1.5',
]
inference_model_list = [
    'google/gemma-2b-it',
    'google/gemma-7b-it',
    'microsoft/phi-2',
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
]
# Each label must match a branch in process_query.
retrieval_method_list = [
    "Simple",
    "Parent & Child",
    "Multi Query",
    "Contextual Compression (chain extraction)",
    "Contextual Compression (query filter)",
    "Contextual Compression (embeddings filter)",
    "Ensemble",
    "Long Context Reorder",
]
# UI definition: inputs in the left column, outputs in the right column.
# NOTE(review): the source's indentation was lost; the nesting of rows and
# columns below is reconstructed - confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Compare Retrieval Methods for PDFs")
    with gr.Row():
        with gr.Column():
            pdf_url = gr.Textbox(label="Enter URL to PDF", value="https://www.berkshirehathaway.com/letters/2023ltr.pdf")
            load_button = gr.Button("Load and process PDF")
            status = gr.Textbox(label="Status")
            # Holds the loaded documents between the load and search clicks.
            docs = gr.State()
            load_button.click(load_pdf, inputs=[pdf_url], outputs=[docs, status])
            query = gr.Textbox(label="Enter your query", value="What does Warren Buffet think about Coca Cola?")
            with gr.Row():
                embedding_model = gr.Dropdown(embedding_model_list, label="Select Embedding Model", value=embedding_model_list[0])
                inference_model = gr.Dropdown(inference_model_list, label="Select Inference Model", value=inference_model_list[3])
                retrieval_method = gr.Dropdown(retrieval_method_list, label="Select Retrieval Method", value=retrieval_method_list[0])
            with gr.Row():
                # precision=0 makes these inputs arrive as ints: chunk sizes and
                # token counts are whole numbers.
                chunk_size = gr.Number(label="Chunk Size", value=1000, precision=0)
                chunk_overlap = gr.Number(label="Chunk Overlap", value=200, precision=0)
            with gr.Row():
                max_new_tokens = gr.Number(label="Max New Tokens", value=100, precision=0)
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.7)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Top P", value=0.9)
            search_button = gr.Button("Retrieval")
        with gr.Column():
            answer = gr.Textbox(label="Answer")
            retrieval_output = gr.Textbox(label="Retrieval Results")
    # process_query returns (retrieval_text, answer), hence this output order.
    search_button.click(process_query, inputs=[docs, query, embedding_model, inference_model, retrieval_method, chunk_size, chunk_overlap, max_new_tokens, temperature, top_p], outputs=[retrieval_output, answer])
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()