Konrad Wojtasik committed on
Commit 0c2b47c
1 Parent(s): 6f972fa
Files changed (2)
  1. app.py +224 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,224 @@
+ import os
+ import string
+ import warnings
+ 
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ from rank_bm25 import BM25Okapi
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
+ from sklearn.feature_extraction import _stop_words
+ from beir.datasets.data_loader_hf import HFDataLoader
+ from beir.reranking.models.mono_t5 import MonoT5
+ 
+ 
+ warnings.filterwarnings("ignore")
+ 
+ auth_token = os.environ.get("auth_token")
+ 
+ 
+ @st.cache_data()
+ def load_data(dataset_type):
+     # Load the test split of the chosen BEIR-PL dataset from the Hugging Face Hub.
+     corpus, queries, qrels = HFDataLoader(hf_repo="clarin-knext/" + dataset_type, streaming=False, keep_in_memory=False).load(split="test")
+     corpus = [doc['text'] for doc in corpus]
+     queries = [query['text'] for query in queries]
+     return queries, corpus
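+ # Note: HFDataLoader also returns qrels (relevance judgments), and the corpus
+ # and query records carry 'id' and 'title' fields alongside 'text'; only the
+ # raw texts are needed for this interactive demo.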
+ 
+ 
+ @st.cache_resource()
+ def bi_encode(bi_enc, passages):
+     global bi_encoder
+     # Use the bi-encoder to encode all passages, so that they can be used for
+     # semantic search. cache_resource is used because the model and the
+     # embedding tensor are shared resources rather than serializable data.
+     bi_encoder = SentenceTransformer(bi_enc, use_auth_token=auth_token)
+ 
+     with st.spinner('Encoding passages into a vector space...'):
+         if bi_enc == 'intfloat/multilingual-e5-base':
+             # E5 models expect a "passage: " prefix on documents at indexing time.
+             corpus_embeddings = bi_encoder.encode(['passage: ' + sentence for sentence in passages], convert_to_tensor=True)
+         else:
+             corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True)
+ 
+     st.success(f"Embeddings computed. Shape: {corpus_embeddings.shape}")
+ 
+     return bi_encoder, corpus_embeddings
+ 
+ 
+ @st.cache_resource()
+ def cross_encode(cross_encoder_name):
+     global cross_encoder
+     # The bi-encoder retrieves candidate passages; the cross-encoder then
+     # re-ranks that result list to improve quality.
+     if cross_encoder_name == "clarin-knext/plt5-base-msmarco":
+         # MonoT5-style reranker with Polish relevance tokens ("true"/"false").
+         cross_encoder = MonoT5(cross_encoder_name, use_amp=False, token_true='▁prawda', token_false='▁fałsz')
+     else:
+         cross_encoder = CrossEncoder(cross_encoder_name)  # e.g. 'cross-encoder/mmarco-mMiniLMv2-L12-H384-v1'
+ 
+     return cross_encoder
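+ # A cross-encoder scores each (query, passage) pair jointly. A minimal usage
+ # sketch (the Polish strings are made-up examples):
+ #
+ #   reranker = cross_encode('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')
+ #   scores = reranker.predict([["co to jest BM25?", "BM25 to funkcja rankingowa."]])
+ #
+ # MonoT5 returns log-probabilities of its "true" token, whereas CrossEncoder
+ # returns raw model scores; both are used only for ranking here, so the
+ # absolute scales do not matter.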
+ 
+ 
+ def bm25_tokenizer(text):
+     # We also compare the results to lexical (keyword) search, using the BM25
+     # algorithm implemented in the rank_bm25 package. The text is lowercased,
+     # punctuation is stripped, and stop words are removed before indexing.
+     # (Note: the stop-word list is English, so it removes little from the
+     # Polish passages.)
+     tokenized_doc = []
+     for token in text.lower().split():
+         token = token.strip(string.punctuation)
+         if len(token) > 0 and token not in _stop_words.ENGLISH_STOP_WORDS:
+             tokenized_doc.append(token)
+     return tokenized_doc
+ 
+ 
+ @st.cache_resource()
+ def bm25_api(passages):
+     tokenized_corpus = [bm25_tokenizer(passage) for passage in passages]
+     bm25 = BM25Okapi(tokenized_corpus)
+     return bm25
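+ # A minimal sketch of querying a BM25 index like the one built above (the
+ # passages and query are made-up examples):
+ #
+ #   bm25 = bm25_api(["Warszawa jest stolicą Polski.", "BM25 to funkcja rankingowa."])
+ #   scores = bm25.get_scores(bm25_tokenizer("stolica Polski"))
+ #   ranking = np.argsort(scores)[::-1]  # passage indices, best match first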
+ 
+ 
+ bi_enc_options = ["sentence-transformers/distiluse-base-multilingual-cased-v1", 'intfloat/multilingual-e5-base', 'nthakur/mcontriever-base-msmarco']
+ # Other candidates: "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1", 'intfloat/e5-base-v2', "neeva/query2query"
+ cross_enc_options = ['clarin-knext/plt5-base-msmarco', 'clarin-knext/herbert-base-reranker-msmarco', 'cross-encoder/mmarco-mMiniLMv2-L12-H384-v1']
+ datasets_options = ["nfcorpus-pl", "scifact-pl", "fiqa-pl"]
+ 
+ 
+ def display_df_as_table(model, top_k, score='score'):
+     # Display the hits (scores and passage texts) as a table.
+     df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:top_k]], columns=['Score', 'Text'])
+     df['Score'] = df['Score'].round(2)
+     return df
+ 
+ 
+ # Streamlit App
+ 
+ st.title("Retrieval BEIR-PL Demo")
+ 
+ """
+ Example of retrieval over the BEIR-PL dataset.
+ """
+ 
+ # window_size = st.sidebar.slider("Paragraph Window Size", min_value=1, max_value=10, value=3, key='slider')
+ 
+ st.sidebar.title("Menu")
+ 
+ dataset_type = st.sidebar.selectbox("Dataset", options=datasets_options, key='dataset_select')
+ 
+ bi_encoder_type = st.sidebar.selectbox("Bi-Encoder", options=bi_enc_options, key='bi_select')
+ 
+ cross_encoder_type = st.sidebar.selectbox("Cross-Encoder", options=cross_enc_options, key='cross_select')
+ 
+ top_k = st.sidebar.slider("Number of Top Hits Generated", min_value=1, max_value=5, value=2)
+ 
+ hide_bm25 = st.sidebar.checkbox("Hide BM25 results?")
+ hide_biencoder = st.sidebar.checkbox("Hide Bi-Encoder results?")
+ hide_crossencoder = st.sidebar.checkbox("Hide Cross-Encoder results?")
+ 
+ 
+ # This function searches the selected BEIR-PL corpus for passages that answer
+ # the query, using BM25, the bi-encoder, and the cross-encoder re-ranker.
+ def search_func(query, bi_encoder_type, top_k=top_k):
+     global bi_encoder, cross_encoder
+ 
+     st.subheader(f"Search Query:\n_{query}_")
+ 
+     ##### BM25 search (lexical search) #####
+     bm25_scores = bm25.get_scores(bm25_tokenizer(query))
+     top_n = np.argpartition(bm25_scores, -top_k)[-top_k:]
+     bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
+     bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
+ 
+     if not hide_bm25:
+         st.subheader(f"Top-{top_k} lexical search (BM25) hits")
+         bm25_df = display_df_as_table(bm25_hits, top_k)
+         st.write(bm25_df.to_html(index=False), unsafe_allow_html=True)
+ 
+     ##### Semantic Search #####
+     # Encode the query using the bi-encoder and find potentially relevant
+     # passages. E5 models expect a "query: " prefix at query time, mirroring
+     # the "passage: " prefix used when the corpus was encoded.
+     if bi_encoder_type == 'intfloat/multilingual-e5-base':
+         question_embedding = bi_encoder.encode('query: ' + query, convert_to_tensor=True)
+     else:
+         question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+     question_embedding = question_embedding.cpu()
+     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k, score_function=util.dot_score)
+     hits = hits[0]  # Get the hits for the first query
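+     # util.semantic_search returns one ranked list per query; each hit is a
+     # dict with 'corpus_id' and 'score', sorted by decreasing score.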
+ 
+     ##### Re-Ranking #####
+     # Score all retrieved passages with the cross-encoder.
+     cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
+     cross_scores = cross_encoder.predict(cross_inp)
+ 
+     # Attach the cross-encoder scores to the hits.
+     for idx in range(len(cross_scores)):
+         hits[idx]['cross-score'] = cross_scores[idx]
+ 
+     if not hide_biencoder:
+         # Output of the top-k hits from the bi-encoder.
+         st.markdown("\n-------------------------\n")
+         st.subheader(f"Top-{top_k} Bi-Encoder Retrieval hits")
+         hits = sorted(hits, key=lambda x: x['score'], reverse=True)
+         biencoder_df = display_df_as_table(hits, top_k)
+         st.write(biencoder_df.to_html(index=False), unsafe_allow_html=True)
+ 
+     if not hide_crossencoder:
+         # Output of the top-k hits from the re-ranker.
+         st.markdown("\n-------------------------\n")
+         st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
+         hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
+         rerank_df = display_df_as_table(hits, top_k, 'cross-score')
+         st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
+ 
+     st.markdown("---")
186
+
187
+ def clear_text():
188
+ st.session_state["text_input"]= ""
189
+
190
+
191
+ question, passages = load_data(dataset_type)
192
+
193
+ st.write(pd.DataFrame(question[:5], columns=["Example queries from dataset"]).to_html(index=False, justify='center'), unsafe_allow_html=True)
194
+
195
+ search_query = st.text_input("Ask your question:",
196
+ value=question[0],
197
+ key="text_input")
198
+
199
+
200
+ col1, col2 = st.columns(2)
201
+
202
+ with col1:
203
+ search = st.button("Search",key='search_but', help='Click to Search!')
204
+
205
+ with col2:
206
+ clear = st.button("Clear Text Input", on_click=clear_text,key='clear',help='Click to clear the search query')
207
+
208
+ if search:
209
+ if bi_encoder_type:
210
+
211
+ with st.spinner(
212
+ text=f"Loading {bi_encoder_type} bi-encoder and embedding document into vector space. This might take a few seconds depending on the length of your document..."
213
+ ):
214
+ bi_encoder, corpus_embeddings = bi_encode(bi_encoder_type,passages)
215
+ cross_encoder = cross_encode(cross_encoder_type)
216
+ bm25 = bm25_api(passages)
217
+
218
+ with st.spinner(
219
+ text="Embedding completed, searching for relevant text for given query and hits..."):
220
+
221
+ search_func(search_query,bi_encoder_type,top_k)
222
+
223
+ st.markdown("""
224
+ """)
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ beir==1.0.1
+ sentence-transformers==2.2.2
+ transformers==4.29.1
+ torch==2.0.1
+ sentencepiece==0.1.95
+ protobuf==3.20.3
+ pandas
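+ # Also imported by app.py but not pinned above (assuming they are not already
+ # provided by the runtime image):
+ streamlit
+ rank_bm25
+ scikit-learn
+ numpy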