|
from sentence_transformers import SentenceTransformer, CrossEncoder, util |
|
import re |
|
import pandas as pd |
|
from newspaper import Article |
|
import docx2txt |
|
from io import StringIO |
|
from PyPDF2 import PdfFileReader |
|
import validators |
|
import nltk |
|
import warnings |
|
import streamlit as st |
|
|
|
# Fetch the NLTK "punkt" resource at start-up; it is downloaded before
# sent_tokenize is imported and used for sentence splitting.
nltk.download("punkt")
|
|
|
from nltk import sent_tokenize |
|
|
|
def extract_text_from_url(url: str):
    """Download the page at ``url`` and return its ``(title, text)``.

    The page is fetched and parsed through an ``Article`` object; the
    parsed title and body text are returned as a 2-tuple, title first.
    """
    page = Article(url)
    page.download()
    page.parse()
    return page.title, page.text
|
|
|
|
|
def extract_text_from_file(file):
    """Extract text from an uploaded file.

    Args:
        file: Uploaded file object exposing a ``type`` attribute with the
            MIME type. Plain-text uploads must also provide ``getvalue()``
            returning the raw bytes; PDF and DOCX uploads are handed to
            ``PdfFileReader`` and ``docx2txt.process`` respectively.

    Returns:
        A ``(text, title)`` tuple. ``title`` is the title stored in the PDF
        metadata for PDF uploads (``None`` when the PDF has no metadata) and
        always ``None`` for plain-text and DOCX uploads.

    Raises:
        ValueError: If the MIME type is not plain text, PDF or DOCX.
        UnicodeDecodeError: If a plain-text upload is not valid UTF-8.
    """
    mime_type = file.type

    if mime_type == "text/plain":
        return file.getvalue().decode("utf-8"), None

    if mime_type == "application/pdf":
        reader = PdfFileReader(file)

        # getDocumentInfo() yields None for PDFs without an info dictionary,
        # so guard the attribute access instead of crashing on such files.
        info = reader.getDocumentInfo()
        pdf_title = info.title if info is not None else None

        page_texts = []
        for page_number in range(reader.numPages):
            try:
                page_texts.append(reader.getPage(page_number).extractText())
            except Exception:
                # Best effort: a page whose text cannot be extracted is
                # skipped so the remaining pages are still returned.
                continue

        return "".join(page_texts), pdf_title

    if (
        mime_type
        == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    ):
        return docx2txt.process(file), None

    # Callers unpack the result as a 2-tuple; falling through and returning
    # None would surface as an unrelated TypeError at the call site.
    raise ValueError(
        f"Unsupported file type: {mime_type!r}; expected a .txt, .pdf or .docx file"
    )
|
|
|
|
|
def preprocess_plain_text(text, window_size=3):
    """Clean raw text and split it into passages of consecutive sentences.

    Cleaning drops non-ASCII characters, removes URLs, ``@mentions`` and
    ``#hashtags``, and replaces every run of characters outside
    ``.,!?%$A-Za-z0-9`` (including all whitespace and line breaks) with a
    single space. Because line breaks are collapsed, the whole text is
    tokenized as one paragraph.

    Args:
        text: Raw document text.
        window_size: Number of consecutive sentences joined into each
            passage. Must be at least 1.

    Returns:
        A list of passage strings; empty when nothing is left after cleaning.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    text = text.encode("ascii", "ignore").decode()
    text = re.sub(r"https*\S+", " ", text)
    text = re.sub(r"@\S+", " ", text)
    text = re.sub(r"#\S+", " ", text)
    # Whitespace is outside the allowed set, so this one substitution also
    # collapses every run of spaces/newlines into a single space.
    text = re.sub(r"[^.,!?%$A-Za-z0-9]+", " ", text).strip()

    if not text:
        return []

    sentences = sent_tokenize(text)

    # Non-overlapping windows of `window_size` sentences; the last window may
    # be shorter.
    return [
        " ".join(sentences[start:start + window_size])
        for start in range(0, len(sentences), window_size)
    ]
|
|
|
|
|
def bi_encode(bi_enc, passages):
    """Load the bi-encoder named ``bi_enc`` and embed ``passages``.

    The loaded model is also stored in the module-level ``bi_encoder``.
    A spinner is shown while encoding and a success message afterwards.

    Returns:
        A ``(model, corpus_embeddings)`` tuple; the embeddings are requested
        as a tensor via ``convert_to_tensor=True``.
    """
    global bi_encoder

    model = SentenceTransformer(bi_enc)
    bi_encoder = model

    with st.spinner('Encoding passages into a vector space...'):
        embeddings = model.encode(
            passages,
            convert_to_tensor=True,
            show_progress_bar=True,
        )

    st.success("Embeddings computed.")

    return model, embeddings
|
|
|
|
|
def cross_encode():
    """Load the MS MARCO MiniLM cross-encoder and return it.

    The model is also stored in the module-level ``cross_encoder``.
    """
    global cross_encoder

    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
    cross_encoder = reranker
    return reranker
|
|
|
|
|
|
|
def display_as_table(model, score='score'):
    """Build a ``Score``/``Text`` DataFrame from the first two hits.

    Args:
        model: Sequence of hit dicts, each holding a ``corpus_id`` and the
            value stored under ``score``.
        score: Key of the hit dict to report in the ``Score`` column.

    Returns:
        A DataFrame whose ``Text`` column is looked up in the module-level
        ``passages`` list and whose ``Score`` column is rounded to two
        decimals.
    """
    rows = []
    for hit in model[:2]:
        rows.append((hit[score], passages[hit['corpus_id']]))

    table = pd.DataFrame(rows, columns=['Score', 'Text'])
    table['Score'] = table['Score'].round(2)
    return table
|
|
|
|
|
|
|
|
|
# Page heading shown at the top of the app.
st.title("Search with Retrieve & Rerank")

# Sentence window size setting (presumably sentences per passage; confirm
# against the preprocessing calls).
window_size = 3

# Model name handed to bi_encode() when a search is triggered.
bi_encoder_type = "multi-qa-mpnet-base-dot-v1"
|
|
|
def search_func(query):
    """Run retrieve-and-rerank for ``query`` and render both result tables.

    Reads the module-level ``bi_encoder``, ``cross_encoder``,
    ``corpus_embeddings`` and ``passages`` prepared by the script. The query
    is embedded with the bi-encoder, the two best passages by dot-product
    score are retrieved, and those passages are re-scored with the
    cross-encoder. One table is written for the bi-encoder ranking and one
    for the cross-encoder ranking.

    Args:
        query: The search query text.
    """
    st.subheader(f"Search Query: {query}")

    # `title` / `pdf_title` only exist when the matching input branch of the
    # script ran, so look them up defensively instead of risking a NameError
    # (e.g. a non-empty but invalid URL combined with an uploaded file).
    script_state = globals()
    if script_state.get("title") is not None:
        st.write(f"Document Header: {script_state['title']}")
    elif script_state.get("pdf_title"):
        st.write(f"Document Header: {script_state['pdf_title']}")

    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    question_embedding = question_embedding.cpu()
    hits = util.semantic_search(
        question_embedding,
        corpus_embeddings,
        top_k=2,
        score_function=util.dot_score,
    )[0]

    if not hits:
        # Nothing to re-rank or display.
        st.warning("No passages were retrieved for this query.")
        return

    # Score each (query, passage) pair with the cross-encoder and attach the
    # result to the corresponding hit.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    st.markdown("\n-------------------------\n")
    st.subheader("Top 2 Bi-Encoder Retrieval hits")
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)

    cross_df = display_as_table(hits)
    st.write(cross_df.to_html(index=False), unsafe_allow_html=True)

    st.markdown("\n-------------------------\n")
    st.subheader("Top-2 Cross-Encoder Re-ranker hits")
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    rerank_df = display_as_table(hits, 'cross-score')
    st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
|
|
|
|
|
def clear_text():
    """Reset both the URL box and the search-query box to empty strings."""
    for widget_key in ("text_url", "text_input"):
        st.session_state[widget_key] = ""
|
|
|
|
|
def clear_search_text():
    """Reset the search-query box to an empty string."""
    state = st.session_state
    state["text_input"] = ""
|
|
|
|
|
# --- Input widgets ---------------------------------------------------------
url_text = st.text_input(
    "Please Enter a url here",
    value="https://en.wikipedia.org/wiki/Virat_Kohli",
    key='text_url',
    on_change=clear_search_text,
)

st.markdown(
    "<h3 style='text-align: center; color: red;'>OR</h3>",
    unsafe_allow_html=True,
)

upload_doc = st.file_uploader("Upload a .txt, .pdf, .docx file", key="upload")

search_query = st.text_input(
    "Please Enter your search query here",
    value="Who is Virat Kohli?",
    key="text_input",
)

# --- Document loading ------------------------------------------------------
# Defaults so that later code never touches an undefined name when no usable
# document was supplied or extraction failed.
title = None
pdf_title = None
passages = []

if validators.url(url_text):
    try:
        title, text = extract_text_from_url(url_text)
    except Exception as exc:
        # UI boundary: report the failure instead of crashing the app.
        st.error(f"Could not extract text from the URL: {exc}")
    else:
        passages = preprocess_plain_text(text, window_size=3)

elif upload_doc:
    try:
        text, pdf_title = extract_text_from_file(upload_doc)
    except Exception as exc:
        # UI boundary: covers unreadable or unsupported uploads.
        st.error(f"Could not extract text from the uploaded file: {exc}")
    else:
        passages = preprocess_plain_text(text, window_size=3)

# --- Action buttons --------------------------------------------------------
col1, col2 = st.columns(2)

with col1:
    search = st.button("Search", key='search_but', help='Click to Search!!')

with col2:
    clear = st.button(
        "Clear Text Input",
        on_click=clear_text,
        key='clear',
        help='Click to clear the URL and query',
    )

# --- Search ----------------------------------------------------------------
if search:
    if not passages:
        # Without passages there is nothing to embed or search.
        st.warning(
            "No text available to search. Please enter a valid URL or "
            "upload a document first."
        )
    elif bi_encoder_type:
        with st.spinner(text="Loading..........................."):
            bi_encoder, corpus_embeddings = bi_encode(bi_encoder_type, passages)
            cross_encoder = cross_encode()

        with st.spinner(
            text="Embedding completed, searching for relevant text for given query and hits..."
        ):
            search_func(search_query)

st.markdown("""
""")
|
|