Jai12345 committed on
Commit 859a24c · 1 Parent(s): 3c34743

Update app.py

Files changed (1)
  1. app.py +137 -58
app.py CHANGED
@@ -1,14 +1,14 @@
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
 import re
-import newspaper
+import pandas as pd
+from newspaper import Article
 import docx2txt
 from io import StringIO
 from PyPDF2 import PdfFileReader
 import validators
-import streamlit as st
 import nltk
-import pandas as pd
-import requests
+import warnings
+import streamlit as st
 
 nltk.download('punkt')
 
@@ -17,40 +17,62 @@ from nltk import sent_tokenize
 warnings.filterwarnings("ignore")
 
 
-def extarct_test_from_url(url: str):
-    article = Article(url,language-"en")
+def extract_text_from_url(url: str):
+    '''Extract text from url'''
+
+    article = Article(url)
     article.download()
-    article.parse
+    article.parse()
 
-    # receiving text
+    # get text
     text = article.text
+
+    # get article title
     title = article.title
+
     return title, text
 
 
 def extract_text_from_file(file):
+    '''Extract text from uploaded file'''
+
+    # read text file
     if file.type == "text/plain":
         # To convert to a string based IO:
         stringio = StringIO(file.getvalue().decode("utf-8"))
+
+        # To read file as string:
         file_text = stringio.read()
+
         return file_text, None
+
+    # read pdf file
     elif file.type == "application/pdf":
         pdfReader = PdfFileReader(file)
         count = pdfReader.numPages
         all_text = ""
         pdf_title = pdfReader.getDocumentInfo().title
+
         for i in range(count):
+
             try:
                 page = pdfReader.getPage(i)
                 all_text += page.extractText()
+
             except:
                 continue
+
         file_text = all_text
+
         return file_text, pdf_title
-    # read docx file
+
+    # read docx file
     elif (
-        file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"):
+        file.type
+        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+    ):
        file_text = docx2txt.process(file)
+
        return file_text, None
 
 
@@ -60,13 +82,18 @@ def preprocess_plain_text(text, window_size=3):
     text = re.sub(r"@\S+", " ", text) # mentions
     text = re.sub(r"#\S+", " ", text) # hashtags
     text = re.sub(r"\s{2,}", " ", text) # over spaces
-    text = re.sub("[^.,!?%$A-Za-z0-9]+", " ", text) # special characters except .,!?
-    # removing spaces
+    # text = re.sub("[^.,!?%$A-Za-z0-9]+", " ", text) # special characters except .,!?
+
+    # break into lines and remove leading and trailing space on each
     lines = [line.strip() for line in text.splitlines()]
-    # break multi-headlines into a line each
+
+    # break multi-headlines into a line each
     chunks = [phrase.strip() for line in lines for phrase in line.split(" ")]
-    # drop blank lines
+
+    # drop blank lines
     text = '\n'.join(chunk for chunk in chunks if chunk)
+
+    # We split this article into paragraphs and then every paragraph into sentences
     paragraphs = []
     for paragraph in text.replace('\n', ' ').split("\n\n"):
         if len(paragraph.strip()) > 0:
@@ -79,13 +106,21 @@ def preprocess_plain_text(text, window_size=3):
             end_idx = min(start_idx + window_size, len(paragraph))
             passages.append(" ".join(paragraph[start_idx:end_idx]))
 
+    st.write(f"Sentences: {sum([len(p) for p in paragraphs])}")
+    st.write(f"Passages: {len(passages)}")
+
     return passages
 
 
-def biencode(bi_enc, passages):
+@st.experimental_memo(suppress_st_warning=True)
+def bi_encode(bi_enc, passages):
     global bi_encoder
+    # We use the Bi-Encoder to encode all passages, so that we can use it with semantic search
     bi_encoder = SentenceTransformer(bi_enc)
 
+    # quantize the model
+    # bi_encoder = quantize_dynamic(model, {Linear, Embedding})
+
     # Compute the embeddings using the multi-process pool
     with st.spinner('Encoding passages into a vector space...'):
         corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
@@ -95,13 +130,17 @@ def biencode(bi_enc, passages):
     return bi_encoder, corpus_embeddings
 
 
+@st.experimental_singleton(suppress_st_warning=True)
 def cross_encode():
     global cross_encoder
-    # cross-encoder to re-rank the results list to improve the quality
+    # We use a cross-encoder to re-rank the results list to improve the quality
     cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
     return cross_encoder
 
 
+bi_enc_options = ["multi-qa-mpnet-base-dot-v1", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"]
+
+
 def display_as_table(model, top_k, score='score'):
     # Display the df with text and scores as a table
     df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:top_k]], columns=['Score', 'Text'])
@@ -110,77 +149,117 @@ def display_as_table(model, top_k, score='score'):
     return df
 
 
+# Streamlit App
+
+st.title("Semantic Search with Retrieve & Rerank 📝")
+
+window_size = st.sidebar.slider("Paragraph Window Size", min_value=1, max_value=10, value=3, key=
+                                'slider')
+
+bi_encoder_type = st.sidebar.selectbox("Bi-Encoder", options=bi_enc_options, key='sbox')
+
+top_k = st.sidebar.slider("Number of Top Hits Generated", min_value=1, max_value=5, value=2)
+
+
+# This function searches the document for passages that answer the query
 def search_func(query, top_k=top_k):
     global bi_encoder, cross_encoder
+
     st.subheader(f"Search Query: {query}")
+
     if url_text:
+
         st.write(f"Document Header: {title}")
+
     elif pdf_title:
+
         st.write(f"Document Header: {pdf_title}")
+
+    # Encode the query using the bi-encoder and find potentially relevant passages
     question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
     question_embedding = question_embedding.cpu()
     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k, score_function=util.dot_score)
-    hits = hits[0]
-    # score all retrieved passages with the cross_encoder
+    hits = hits[0]  # Get the hits for the first query
+
+    # Now, score all retrieved passages with the cross_encoder
     cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
     cross_scores = cross_encoder.predict(cross_inp)
+
     # Sort results by the cross-encoder scores
     for idx in range(len(cross_scores)):
         hits[idx]['cross-score'] = cross_scores[idx]
-    # Output of top-3 hits
+
+    # Output of top hits from bi-encoder
+    st.markdown("\n-------------------------\n")
+    st.subheader(f"Top-{top_k} Bi-Encoder Retrieval hits")
+    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
+
+    cross_df = display_as_table(hits, top_k)
+    st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
+
+    # Output of top hits from re-ranker
     st.markdown("\n-------------------------\n")
     st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
     hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
 
-    rerank_df = display_df_as_table(hits, top_k, 'cross-score')
+    rerank_df = display_as_table(hits, top_k, 'cross-score')
     st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
 
 
 def clear_text():
     st.session_state["text_url"] = ""
     st.session_state["text_input"] = ""
 
 
 def clear_search_text():
     st.session_state["text_input"] = ""
 
 
-url_text = st.text_input("Enter a url here", value="https://en.wikipedia.org/wiki/Virat_Kohli", key='text_url',
-                         on_change=clear_search_text)
+url_text = st.text_input("Please Enter a url here",
+                         value="https://www.rba.gov.au/monetary-policy/rba-board-minutes/2022/2022-05-03.html",
+                         key='text_url', on_change=clear_search_text)
+
 st.markdown(
     "<h3 style='text-align: center; color: red;'>OR</h3>",
-    unsafe_allow_html=True, )
+    unsafe_allow_html=True,
+)
+
 upload_doc = st.file_uploader("Upload a .txt, .pdf, .docx file", key="upload")
 
-search_query = st.text_input("Enter your search query here", value="How many Centuries Virat Kohli scored?",
-                             key="text_input")
+search_query = st.text_input("Please Enter your search query here",
+                             value="What are the expectations for inflation for Australia?", key="text_input")
 
 if validators.url(url_text):
     # if input is URL
     title, text = extract_text_from_url(url_text)
     passages = preprocess_plain_text(text, window_size=window_size)
 
 elif upload_doc:
 
     text, pdf_title = extract_text_from_file(upload_doc)
     passages = preprocess_plain_text(text, window_size=window_size)
 
 col1, col2 = st.columns(2)
 
 with col1:
     search = st.button("Search", key='search_but', help='Click to Search!!')
 
 with col2:
     clear = st.button("Clear Text Input", on_click=clear_text, key='clear',
                       help='Click to clear the URL input and search query')
 
 if search:
     if bi_encoder_type:
         with st.spinner(
                 text=f"Loading {bi_encoder_type} bi-encoder and embedding document into vector space. This might take a few seconds depending on the length of your document..."
         ):
             bi_encoder, corpus_embeddings = bi_encode(bi_encoder_type, passages)
             cross_encoder = cross_encode()
-            bm25 = bm25_api(passages)
 
     with st.spinner(
             text="Embedding completed, searching for relevant text for given query and hits..."):
         search_func(search_query, top_k)
 
-st.markdown(""" """)
+st.markdown("""
+""")
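For reference, the retrieve-and-rerank flow that `search_func` wires together can be exercised outside Streamlit. The following is a minimal sketch using the same model names and the same `util.semantic_search` and `CrossEncoder.predict` calls as the diff; the toy passages and query are illustrative stand-ins for the app's extracted document, not part of the commit.

from sentence_transformers import SentenceTransformer, CrossEncoder, util

# Illustrative stand-ins for the app's preprocessed passages
passages = [
    "The Board decided to increase the cash rate target by 25 basis points.",
    "Inflation is expected to rise further before declining back towards the target range.",
    "Members discussed global supply-chain pressures and energy prices.",
]

# Default bi-encoder option and the cross-encoder named in the diff
bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

query = "What are the expectations for inflation for Australia?"

# Retrieve: dot-product search over bi-encoder embeddings, as in search_func
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True)
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
hits = util.semantic_search(question_embedding, corpus_embeddings,
                            top_k=2, score_function=util.dot_score)[0]

# Re-rank: the cross-encoder scores each (query, passage) pair jointly
cross_scores = cross_encoder.predict([[query, passages[hit["corpus_id"]]] for hit in hits])
for hit, score in zip(hits, cross_scores):
    hit["cross-score"] = float(score)

for hit in sorted(hits, key=lambda x: x["cross-score"], reverse=True):
    print(round(hit["cross-score"], 3), passages[hit["corpus_id"]])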
 
 
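The commit also caches the heavy objects: `@st.experimental_memo` memoizes the embedding step across reruns, and `@st.experimental_singleton` keeps one cross-encoder per process. A minimal sketch of the same pattern, assuming a Streamlit version that still ships these experimental decorators (they were later renamed `st.cache_data` / `st.cache_resource`); the function names here are illustrative, not from the diff.

import streamlit as st
from sentence_transformers import SentenceTransformer, CrossEncoder


@st.experimental_memo(suppress_st_warning=True)
def embed_passages(model_name, passages):
    # Recomputed only when model_name or passages change; otherwise served from cache
    model = SentenceTransformer(model_name)
    return model.encode(passages, convert_to_tensor=True).cpu()


@st.experimental_singleton(suppress_st_warning=True)
def get_cross_encoder():
    # Constructed once per process and shared across sessions
    return CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")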