Commit 966108f · Parent: 462639d
Initial commit
Files changed:
- .gitattributes +2 -0
- app.py +80 -0
- benchmark.py +47 -0
- main.py +33 -0
- retrieval_pipeline/__init__.py +2 -0
- retrieval_pipeline/hybrid_search.py +95 -0
- retrieval_pipeline/main.py +98 -0
- retrieval_pipeline/utils.py +5 -0
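This commit adds a Streamlit search demo (app.py), a retrieval benchmark script (benchmark.py), a CLI entry point (main.py), and a retrieval_pipeline package that combines dense (multilingual-e5-large embeddings) and sparse (BM25) Elasticsearch retrieval behind a cross-encoder reranker.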
.gitattributes
CHANGED
```diff
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+multilingual-e5-large filter=lfs diff=lfs merge=lfs -text
+multilingual-e5-large/* filter=lfs diff=lfs merge=lfs -text
```
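The two added rules tell Git LFS to handle the local `multilingual-e5-large` directory and everything under it, so the embedding model weights loaded in retrieval_pipeline/main.py are stored as LFS objects rather than as regular git blobs.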
app.py
ADDED
@@ -0,0 +1,80 @@
```python
import streamlit as st
from dotenv import load_dotenv
import json
import os, time
import uuid

from retrieval_pipeline import get_retriever, get_compression_retriever
import benchmark


def get_result(query, compression_retriever):
    t0 = time.time()
    retrieved_chunks = compression_retriever.get_relevant_documents(query)
    latency = time.time() - t0
    return retrieved_chunks, latency


st.set_page_config(
    layout="wide",
    page_title="Retrieval Demo",
)


def setup():
    load_dotenv()
    ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')

    retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
    compression_retriever = get_compression_retriever(retriever)
    return compression_retriever


def main():
    st.title("Part 3: Search")

    st.sidebar.info("""
    **Model Options:**
    - **Nano**: ~4 MB, blazing-fast model with competitive performance (ranking precision).
    - **Small**: ~34 MB, slightly slower, with the best performance (ranking precision).
    - **Medium**: ~110 MB, slower model with the best zero-shot performance (ranking precision).
    - **Large**: ~150 MB, slower model with competitive performance (ranking precision) for 100+ languages.
    """)

    with st.spinner('Setting up...'):
        compression_retriever = setup()

    with st.expander("Tech Stack Used"):
        st.markdown("""
        **Flash Rank**: an ultra-lite, super-fast Python library for search & retrieval re-ranking.

        - **Ultra-lite**: no heavy dependencies; runs on CPU with a tiny ~4 MB reranking model.
        - **Super-fast**: speed depends on the number of tokens in the passages and query, plus the model depth.
        - **Cost-efficient**: ideal for serverless deployments with low memory and strict time budgets.
        - **Based on state-of-the-art cross-encoders**: includes models such as ms-marco-TinyBERT-L-2-v2 (default), ms-marco-MiniLM-L-12-v2, rank-T5-flan, and ms-marco-MultiBERT-L-12.
        - **Sleek models for efficiency**: designed for minimal overhead in user-facing scenarios.

        _Flash Rank is tailored for scenarios requiring efficient and effective reranking, balancing performance with resource usage._
        """)

    with st.form(key='input_form'):
        query_input = st.text_area("Query Input")
        submit_button = st.form_submit_button(label='Retrieve')

    if submit_button:
        st.session_state.submitted = True

    if 'submitted' in st.session_state:
        with st.spinner('Processing...'):
            result, latency = get_result(query_input, compression_retriever)
        st.subheader("Please find the retrieved documents below 👇")
        # time.time() differences are in seconds; convert before labelling as ms
        st.write(f"latency: {latency * 1000:.0f} ms")
        # Document objects are not directly JSON-serializable; convert to dicts first
        st.json([{"page_content": doc.page_content, "metadata": doc.metadata} for doc in result])


if __name__ == "__main__":
    main()
```
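The demo is launched the usual Streamlit way, e.g. `streamlit run app.py`, and assumes a `.env` file providing `ELASTICSEARCH_URL` plus a reachable Elasticsearch instance hosting the `masa.ai` index.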
benchmark.py
ADDED
@@ -0,0 +1,47 @@
```python
import pandas as pd
from retrieval_pipeline import get_relevant_documents
import tqdm, time


TOP_N = 3


def get_benchmark_result(path, retriever):
    df = pd.read_csv(path)
    retrieval_result = []
    # one list per rank, each collecting the text retrieved at that rank for every row
    query_result = [[] for _ in range(TOP_N)]
    retrieval_latency = []

    for _, row in tqdm.tqdm(df.iterrows(), total=len(df)):
        query = row['query']
        target = row['body']

        t0 = time.time()
        results = retriever.get_relevant_documents(query)
        t = time.time() - t0
        retrieval_latency.append(t)

        result_content = [result.page_content for result in results]

        # append one entry per rank (None when fewer than TOP_N results come back),
        # so the q1..qN columns line up with the dataframe rows
        for rank in range(TOP_N):
            query_result[rank].append(result_content[rank] if rank < len(result_content) else None)

        retrieval_result.append("Success" if target in result_content else "Failed")

    df["retrieval_result"] = retrieval_result
    df["retrieval_latency"] = retrieval_latency
    for i in range(TOP_N):
        df[f'q{i+1}'] = query_result[i]
    df.to_csv('benchmark_result q3 topk 5.csv')
    print(df['retrieval_result'].value_counts())
    print(df['retrieval_result'].value_counts() / len(df))
```
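`get_benchmark_result` expects a CSV with at least a `query` column and a `body` column holding the ground-truth passage; a retrieval counts as "Success" when the target body appears verbatim among the retrieved texts. A minimal sketch of producing such a file (the column names come from the code above; the example rows are hypothetical):

```python
import pandas as pd

# Hypothetical rows; only the 'query' and 'body' columns are required
# by get_benchmark_result.
pd.DataFrame([
    {"query": "Gunung Semeru", "body": "Semeru is a volcano in East Java..."},
    {"query": "Candi Borobudur", "body": "Borobudur is a 9th-century Buddhist temple..."},
]).to_csv("benchmark-reranker.csv", index=False)
```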
main.py
ADDED
@@ -0,0 +1,33 @@
```python
from dotenv import load_dotenv
import json
import os
import uuid

from retrieval_pipeline import get_retriever, get_compression_retriever
import benchmark

load_dotenv()
ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')

os.environ["ES_ENDPOINT"] = ELASTICSEARCH_URL
print(ELASTICSEARCH_URL)

if __name__ == "__main__":
    retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
    compression_retriever = get_compression_retriever(retriever)

    # smoke test: retrieve and print documents for a sample query
    retrieved_chunks = compression_retriever.get_relevant_documents('Gunung Semeru')
    print(retrieved_chunks)

    benchmark.get_benchmark_result("benchmark-reranker.csv", retriever=compression_retriever)
```
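Running `python main.py` therefore performs a single smoke-test query against the `masa.ai` index and then kicks off the benchmark; it assumes `ELASTICSEARCH_URL` is set (e.g. via `.env`) and that `benchmark-reranker.csv` exists next to the script.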
retrieval_pipeline/__init__.py
ADDED
@@ -0,0 +1,2 @@
```python
from retrieval_pipeline.main import *
from retrieval_pipeline.hybrid_search import *
```
retrieval_pipeline/hybrid_search.py
ADDED
@@ -0,0 +1,95 @@
```python
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_community.retrievers import ElasticSearchBM25Retriever
from langchain_community.vectorstores.elastic_vector_search import ElasticVectorSearch
from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
import elasticsearch

import logging  # used by reset_indices; missing from the original imports

from typing import Optional, List


class HybridRetriever(BaseRetriever):
    dense_db: ElasticVectorSearch
    dense_retriever: VectorStoreRetriever
    sparse_retriever: ElasticSearchBM25Retriever
    index_dense: str
    index_sparse: str
    top_k_dense: int
    top_k_sparse: int

    is_training: bool = False

    @classmethod
    def create(
        cls, dense_db, dense_retriever, sparse_retriever, index_dense, index_sparse, top_k_dense, top_k_sparse
    ):
        return cls(
            dense_db=dense_db,
            dense_retriever=dense_retriever,
            sparse_retriever=sparse_retriever,
            index_dense=index_dense,
            index_sparse=index_sparse,
            top_k_dense=top_k_dense,
            top_k_sparse=top_k_sparse,
        )

    def reset_indices(self):
        result = self.dense_db.client.indices.delete(
            index=self.index_dense,
            ignore_unavailable=True,
            allow_no_indices=True,
        )
        logging.info('dense_db delete: %s', result)

        result = self.sparse_retriever.client.indices.delete(
            index=self.index_sparse,
            ignore_unavailable=True,
            allow_no_indices=True,
        )
        logging.info('sparse_retriever delete: %s', result)

    def add_documents(self, documents, batch_size=25):
        # index documents into both stores in batches; the BM25 index only needs raw text
        for i in range(0, len(documents), batch_size):
            print('batch', i)
            dense_batch = documents[i:i + batch_size]
            sparse_batch = [doc.page_content for doc in dense_batch]
            self.dense_retriever.add_documents(dense_batch)
            self.sparse_retriever.add_texts(sparse_batch)

    def _get_relevant_documents(self, query: str, **kwargs):
        dense_results = self.dense_retriever.get_relevant_documents(query)[:self.top_k_dense]
        sparse_results = self.sparse_retriever.get_relevant_documents(query)[:self.top_k_sparse]

        # Combine results by simple concatenation; a fusion strategy
        # (deduplication, score-based merging) could be plugged in here.
        combined_results = dense_results + sparse_results

        documents = [Document(page_content=doc.page_content, metadata=doc.metadata) for doc in combined_results]
        return documents

    async def aget_relevant_documents(self, query: str):
        raise NotImplementedError


def get_dense_db(elasticsearch_url, index_dense, embeddings):
    dense_db = ElasticVectorSearch(
        elasticsearch_url=elasticsearch_url,
        index_name=index_dense,
        embedding=embeddings,
    )
    return dense_db


def get_sparse_retriever(elasticsearch_url, index_sparse):
    sparse_retriever = ElasticSearchBM25Retriever(
        client=elasticsearch.Elasticsearch(elasticsearch_url),
        index_name=index_sparse,
    )
    return sparse_retriever
```
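The comment in `_get_relevant_documents` notes that a combination strategy is still needed: plain concatenation can return the same passage twice when the dense and sparse stores agree. A minimal sketch of one option, deduplicating by page content while preserving order (the helper name `dedupe_by_content` is hypothetical, not part of the module):

```python
from typing import List
from langchain_core.documents import Document

def dedupe_by_content(documents: List[Document]) -> List[Document]:
    """Drop later duplicates so overlapping dense/sparse hits appear once."""
    seen = set()
    unique = []
    for doc in documents:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            unique.append(doc)
    return unique

# e.g. inside _get_relevant_documents:
# combined_results = dedupe_by_content(dense_results + sparse_results)
```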
retrieval_pipeline/main.py
ADDED
@@ -0,0 +1,98 @@
```python
from langchain.vectorstores import ElasticVectorSearch
from langchain.llms import OpenAI, HuggingFaceHub
from langchain_community.embeddings import HuggingFaceEmbeddings
from retrieval_pipeline.hybrid_search import HybridRetriever, get_dense_db, get_sparse_retriever
from retrieval_pipeline.utils import get_hybrid_indexes

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# FlashRank (assumed dependency): provides the Ranker and RerankRequest
# referenced by get_relevant_documents below.
from flashrank import Ranker, RerankRequest

import logging
import tqdm

ranker = Ranker()  # defaults to the ~4 MB ms-marco-TinyBERT-L-2-v2 model


def get_compression_retriever(retriever):
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
    compressor = CrossEncoderReranker(model=model, top_n=3)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    return compression_retriever


# Embedding model loader
def get_huggingface_embeddings(model_name):
    logging.info("Loading HuggingFace embeddings")
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return embeddings


def get_vectorstore(index_name, embeddings, elasticsearch_url=None):
    logging.info("Loading vectorstore")

    index_dense, index_sparse = get_hybrid_indexes(index_name)

    dense_db = get_dense_db(elasticsearch_url, index_dense, embeddings)
    dense_retriever = dense_db.as_retriever()

    sparse_retriever = get_sparse_retriever(elasticsearch_url, index_sparse)

    hybrid_retriever = HybridRetriever(
        dense_db=dense_db,
        dense_retriever=dense_retriever,
        sparse_retriever=sparse_retriever,
        index_dense=index_dense,
        index_sparse=index_sparse,
        top_k_dense=2,
        top_k_sparse=3,
    )
    return hybrid_retriever


def get_retriever(index, elasticsearch_url):
    # the local multilingual-e5-large directory tracked via Git LFS
    embeddings = get_huggingface_embeddings(model_name="multilingual-e5-large")

    retriever = get_vectorstore(
        index,
        embeddings=embeddings,
        elasticsearch_url=elasticsearch_url,
    )
    return retriever


def get_relevant_documents(query, retriever, top_k):
    # top_k is currently unused; FlashRank reranks all retrieved passages
    results = retriever.get_relevant_documents(query)
    passages = [{
        "id": i,
        "text": result.page_content,
    } for i, result in enumerate(results)]

    reranked_result = ranker.rerank(RerankRequest(query=query, passages=passages))
    return reranked_result
```
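A minimal end-to-end usage sketch (the Elasticsearch URL below is illustrative; in the app it comes from the `ELASTICSEARCH_URL` environment variable):

```python
from retrieval_pipeline import get_retriever, get_compression_retriever

# hybrid dense+sparse retriever over the 'masa.ai' indices
retriever = get_retriever(index='masa.ai', elasticsearch_url='http://localhost:9200')

# wrap it with the BAAI/bge-reranker-base cross-encoder, keeping the top 3 passages
reranking_retriever = get_compression_retriever(retriever)

docs = reranking_retriever.get_relevant_documents('Gunung Semeru')
```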
retrieval_pipeline/utils.py
ADDED
@@ -0,0 +1,5 @@
```python
def get_hybrid_indexes(index_name):
    index_dense = f'{index_name}-dense'
    index_sparse = f'{index_name}-sparse'

    return index_dense, index_sparse
```
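For example, `get_hybrid_indexes('masa.ai')` returns `('masa.ai-dense', 'masa.ai-sparse')`, which is how the dense and sparse Elasticsearch indices used by hybrid_search.py get their names.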