Commit beb7154
Parent(s): 3676673
Fix page always restarting
Files changed:
- .gitignore +1 -0
- app.py +39 -39
- calculate_mmr.py +34 -0
- retrieval_pipeline/cache.py +42 -44
.gitignore
ADDED
@@ -0,0 +1 @@
+**/__pycache__/
app.py
CHANGED
@@ -5,12 +5,13 @@ import os, time
 import uuid
 
 from retrieval_pipeline import get_retriever, get_compression_retriever
+from retrieval_pipeline.cache import SemanticCache
 import benchmark
 
 
-def get_result(query, retriever):
+def get_result(query, retriever, use_cache):
     t0 = time.time()
-    retrieved_chunks = retriever.get_relevant_documents(query)
+    retrieved_chunks = retriever.get_relevant_documents(query, use_cache=use_cache)
     latency = time.time() - t0
     return retrieved_chunks, latency
 
@@ -19,62 +20,61 @@ st.set_page_config(
     page_title="Retrieval Demo"
 )
 
-load_dotenv()
-ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
-
-retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
-compression_retriever = get_compression_retriever(retriever)
+@st.cache_resource
+def setup_retriever():
+    load_dotenv()
+    ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
+
+    retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
+    compression_retriever = get_compression_retriever(retriever)
+    semantic_cache_retriever = SemanticCache(compression_retriever)
+    return semantic_cache_retriever
+
+
+def retrieval_page(retriever, use_cache):
+    with st.form(key='input_form'):
+        query_input = st.text_area("Query Input")
+        submit_button = st.form_submit_button(label='Retrieve')
+
+    if submit_button:
+        with st.spinner('Processing...'):
+            result, latency = get_result(query_input, retriever=retriever, use_cache=use_cache)
+            st.subheader("Please find the retrieved documents below 👇")
+            st.write("latency:", latency, " s")
+            st.json(result)
+
 
 
 def main():
     st.title("Part 3: Search")
-    # menu = ["Nano", "Small", "Medium", "Large"]
-    # choice = st.sidebar.selectbox("Choose", menu)
+    use_cache = st.sidebar.toggle("Use cache", value=True)
 
     st.sidebar.info("""
-    **...
-    - **...
-    - **...
-    - **...
-    - **Large**: ~150MB, slower model with competitive performance (ranking precision) for 100+ languages.
+    **Retrieval Pipeline Evaluation Result:**
+    - **MRR**: 0.756
+    - **Avg. Latency**: 4.50s (on CPU, with cache turned off)
+    - **Benchmark Result**: https://docs.google.com/spreadsheets/d/1WJnb8BieoxLch0gvb53ZzMS70r_G35PKm731ubdeNCA/edit?usp=sharing
     """)
 
     with st.spinner('Setting up...'):
-        ...
+        retriever = setup_retriever()
 
-    with st.expander("Tech Stack Used"):
-        st.markdown("""
-        **Flash Rank**: Ultra-lite & Super-fast Python library for search & retrieval re-ranking.
-
-        - **...
-        - **Super-fast**: Speed depends on the number of tokens in passages and query, plus model depth.
-        - **Cost-efficient**: Ideal for serverless deployments with low memory and time requirements.
-        - **Based on State-of-the-Art Cross-encoders**: Includes models like ms-marco-TinyBERT-L-2-v2 (default), ms-marco-MiniLM-L-12-v2, rank-T5-flan, and ms-marco-MultiBERT-L-12.
-        - **Sleek Models for Efficiency**: Designed for minimal overhead in user-facing scenarios.
-        ...
-
-    with st.form(key='input_form'):
-        query_input = st.text_area("Query Input")
-        # context_input = st.text_area("Context Input")
-        submit_button = st.form_submit_button(label='Retrieve')
-
-    if submit_button:
-        st.session_state.submitted = True
-
-    if 'submitted' in st.session_state:
-        with st.spinner('Processing...'):
-            result, latency = get_result(query_input, compression_retriever)
-            st.subheader("Please find the retrieved documents below 👇")
-            st.write("latency:", latency, " ms")
-            st.json(result)
+    retrieval_page(retriever, use_cache)
 
+    # with st.expander("Tech Stack Used"):
+    #     st.markdown("""
+    #     **Flash Rank**: Ultra-lite & Super-fast Python library for search & retrieval re-ranking.
+
+    #     - **Ultra-lite**: No heavy dependencies. Runs on CPU with a tiny ~4MB reranking model.
+    #     - **Super-fast**: Speed depends on the number of tokens in passages and query, plus model depth.
+    #     - **Cost-efficient**: Ideal for serverless deployments with low memory and time requirements.
+    #     - **Based on State-of-the-Art Cross-encoders**: Includes models like ms-marco-TinyBERT-L-2-v2 (default), ms-marco-MiniLM-L-12-v2, rank-T5-flan, and ms-marco-MultiBERT-L-12.
+    #     - **Sleek Models for Efficiency**: Designed for minimal overhead in user-facing scenarios.
+
+    #     _Flash Rank is tailored for scenarios requiring efficient and effective reranking, balancing performance with resource usage._
+    #     """)
 
 
 if __name__ == "__main__":
     main()
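The fix named in the commit title lives in this file: Streamlit re-executes the whole script from the top on every widget interaction, so the old module-level calls to get_retriever and get_compression_retriever rebuilt the retriever stack on each rerun, which made the page appear to restart constantly. Wrapping the setup in @st.cache_resource memoizes the returned object for the lifetime of the process. A minimal sketch of that pattern, with a hypothetical expensive_setup standing in for the real retriever stack:

import time
import streamlit as st

@st.cache_resource  # memoize the return value across script reruns
def expensive_setup():
    # Stand-in for get_retriever / get_compression_retriever / SemanticCache.
    # This body runs once per process, not on every widget interaction.
    time.sleep(2)  # simulate slow initialization
    return {"ready": True}

resource = expensive_setup()   # later reruns get the cached object back instantly
st.button("Rerun")             # clicking triggers a rerun, but not the setup

Two smaller corrections ride along: the latency label changes from " ms" to " s", which matches what a time.time() difference actually measures, and the form handling drops the st.session_state.submitted flag in favor of acting directly on submit_button.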
calculate_mmr.py
ADDED
@@ -0,0 +1,34 @@
+import argparse
+import pandas as pd
+
+def find_reciprocal_rank(target, row, u, k):
+    for i in range(k):
+        q = row['q{}'.format(i+1)]
+        if target == q:
+            print(1/(i+1))
+            return 1/(i+1)
+    return 0
+
+def main(filename, k):
+    df = pd.read_csv(filename)
+    u = len(df)
+
+    sum_ = 0
+    for _, row in df.iterrows():
+        target = row['body']
+        reciprocal_rank = find_reciprocal_rank(target, row, u, k)
+        sum_ += reciprocal_rank
+    mrr = sum_ / u
+
+    print('U:', u)
+    print('MRR: ', mrr)
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('filename')
+    parser.add_argument('-k', type=int)
+    args = parser.parse_args()
+
+    main(filename=args.filename, k=args.k)
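Despite the filename, the new script computes MRR (mean reciprocal rank), the metric quoted in the sidebar above: for each row it looks for the ground-truth passage (column 'body') among the top-k retrieved passages (columns q1 through qk, in rank order), scores the row 1/rank, or 0 if the passage never appears, and averages over all U rows. The u argument of find_reciprocal_rank is accepted but never used. A hypothetical smoke test, assuming that CSV layout:

import pandas as pd

# Two queries: the truth appears at rank 2 (RR = 1/2) and at rank 1 (RR = 1),
# so MRR = (1/2 + 1) / 2 = 0.75.
pd.DataFrame([
    {"body": "doc A", "q1": "doc B", "q2": "doc A", "q3": "doc C"},
    {"body": "doc D", "q1": "doc D", "q2": "doc E", "q3": "doc F"},
]).to_csv("toy_eval.csv", index=False)

# $ python calculate_mmr.py toy_eval.csv -k 3
# 0.5
# 1.0
# U: 2
# MRR:  0.75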
retrieval_pipeline/cache.py
CHANGED
@@ -46,49 +46,47 @@ class SemanticCache:
         results = self.retriever.get_relevant_documents(query_text)
         return results
 
-    def get_relevant_documents(self, query: str) -> str:
+    def get_relevant_documents(self, query: str, use_cache=True) -> str:
         # Method to retrieve an answer from the cache or generate a new one
         start_time = time.time()
-        ...
-        # except Exception as e:
-        #     raise RuntimeError(f"Error during 'get_relevant_documents' method: {e}")
+        try:
+            # First we obtain the embeddings corresponding to the user query
+            embedding = self.encoder.encode([query])
+
+            # Search for the nearest neighbor in the index
+            self.index.nprobe = 8
+            D, I = self.index.search(embedding, 1)
+
+            if use_cache:
+                if D[0] >= 0:
+                    if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
+                        row_id = int(I[0][0])
+
+                        print("Answer recovered from Cache. ")
+                        print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
+                        print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
+
+                        end_time = time.time()
+                        elapsed_time = end_time - start_time
+                        print(f"Time taken: {elapsed_time:.3f} seconds")
+                        return [Document(**doc) for doc in self.cache["answers"][row_id]]
+
+            # Handle the case when there are not enough results
+            # or Euclidean distance is not met, asking to chromaDB.
+            answer = self.query_database(query)
+            # response_text = answer["documents"][0][0]
+
+            self.cache["query"].append(query)
+            self.cache["embeddings"].append(embedding[0].tolist())
+            self.cache["answers"].append([doc.__dict__ for doc in answer])
+
+
+            self.index.add(embedding)
+            store_cache(self.json_file, self.cache)
+            end_time = time.time()
+            elapsed_time = end_time - start_time
+            print(f"Time taken: {elapsed_time:.3f} seconds")
+
+            return answer
+        except Exception as e:
+            raise RuntimeError(f"Error during 'get_relevant_documents' method: {e}")
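The now-active body of get_relevant_documents is a standard FAISS semantic-cache lookup: encode the query, search the index for the single nearest stored embedding, and serve the cached answer when the distance clears euclidean_threshold; on a miss it falls through to query_database, then appends the query, its embedding, and the serialized documents (doc.__dict__, reversed on a hit by Document(**doc)) to the cache and persists everything with store_cache. The nprobe setting implies an IVF-style index, and the D[0] >= 0 guard compares a length-1 numpy array, which only has a well-defined truth value because the search asks for exactly one neighbor. A self-contained sketch of the hit/miss decision, with the index type, embedding width, and threshold assumed rather than taken from the repo:

import faiss
import numpy as np

d = 384                                       # embedding width (assumption)
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, 16)  # IVF index, so nprobe applies
vectors = np.random.rand(1000, d).astype("float32")
index.train(vectors)                          # IVF indexes must be trained first
index.add(vectors)

query = np.random.rand(1, d).astype("float32")
index.nprobe = 8                              # same setting as the code above
D, I = index.search(query, 1)                 # squared-L2 distances, row ids

euclidean_threshold = 0.35                    # hypothetical value
if I[0][0] >= 0 and D[0][0] <= euclidean_threshold:
    print(f"cache hit: row {int(I[0][0])}, distance {D[0][0]:.3f}")
else:
    print("cache miss: fall through to query_database(query)")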