# VQArt/src/models/search_engine.py
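"""Whoosh-based passage retrieval over locally scraped Wikipedia articles.

Each article is split into overlapping fixed-length passages, a full-text
index is built (or reopened) on disk, and the top-k passages for a query are
retrieved using the index's default scoring.
"""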
import os
import sys

from tqdm import tqdm

# whoosh: full-text indexing and searching
from whoosh import qparser
from whoosh.analysis import StemmingAnalyzer
from whoosh.fields import ID, TEXT, Schema
from whoosh.index import create_in, open_dir

# Make the scraping helper importable when running from the project root.
sys.path.insert(1, "src/data")
# import src.data.wiki_scrape as wiki_scrape
import wiki_scrape
class IR(object):
    """Passage-level information retrieval over scraped Wikipedia articles."""

    def __init__(self,
                 max_passage_length=800,
                 overlap=0.4,
                 passages_limit=10000,
                 data_path='data/wiki_articles',
                 index_path='index'):
        self.max_passage_length = max_passage_length
        self.overlap = overlap
        self.passages_limit = passages_limit
        self.data_path = data_path
        self.index_path = index_path
        self.ix = None

        # Load (scraping first, if necessary) the passages once, so retrieved
        # document ids can later be mapped back to their text.
        self.passages = self.__load_passages()

        if os.path.exists(self.index_path):
            print(f'Index already exists in the directory {self.index_path}')
            print('Skipping building the index...')
            self.ix = open_dir(self.index_path)
        else:
            self.__create_index(self.passages)
    def __create_passages_from_article(self, content):
        """Split an article into overlapping passages of at most max_passage_length characters."""
        passages = []
        passage_diff = int(self.max_passage_length * (1 - self.overlap))
        for i in range(0, len(content), passage_diff):
            passages.append(content[i: i + self.max_passage_length])
        return passages
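    # Note: with the default max_passage_length = 800 and overlap = 0.4, the
    # window above advances by int(800 * 0.6) = 480 characters, so consecutive
    # full-length passages share their last/first 320 characters.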
    def __scrape_wiki_if_not_exists(self):
        """Scrape the Wikipedia articles into data_path if the directory is missing or empty."""
        if not os.path.exists(self.data_path):
            os.makedirs(self.data_path)
        if len(os.listdir(self.data_path)) == 0:
            print('No Wiki articles. Scraping...')
            wiki_scrape.scrape('src/data/entities.txt', self.data_path)
    def __load_passages(self):
        """Read every .txt article in data_path and split it into passages."""
        self.__scrape_wiki_if_not_exists()
        passages = []
        count = 0
        directory = os.fsencode(self.data_path)
        for file in os.listdir(directory):
            filename = os.fsdecode(file)
            if not filename.endswith(".txt"):
                continue
            with open(os.path.join(self.data_path, filename), 'r', encoding='utf-8') as f:
                content = f.read()
            article_passages = self.__create_passages_from_article(content)
            # print(f'Created {len(article_passages)} passages')
            passages.extend(article_passages)
            count += 1
            # Note: this caps the number of article *files* read, not the
            # number of passages produced.
            if count == self.passages_limit:
                break
        return passages
    def __create_index(self, passages):
        """Build a new Whoosh index over the given passages in index_path."""
        # Create the index directory
        os.mkdir(self.index_path)

        # Schema definition:
        # - id: type ID, unique, stored; doc id in the order the passages were loaded
        # - text: type TEXT processed by StemmingAnalyzer; not stored; content of the passage
        schema = Schema(id=ID(stored=True, unique=True),
                        text=TEXT(analyzer=StemmingAnalyzer()))

        # Create an index in the index directory
        self.ix = create_in(self.index_path, schema)
        writer = self.ix.writer()  # run once! or restart runtime

        # Add each passage to the index; its document id is its position in the passages list
        for doc_id, passage_text in enumerate(tqdm(passages, desc='Building index')):
            writer.add_document(id=str(doc_id), text=passage_text)

        # Save the added documents
        writer.commit()
        print("Index successfully created")
    def retrieve_documents(self, query, topk):
        """Return the topk passages matching the query, together with their scores."""
        scores = []
        text = []

        # Open the searcher for reading the index. The default BM25F algorithm is used for scoring.
        with self.ix.searcher() as searcher:
            # Define the query parser ('text' is the default field to search) and parse the input query.
            q = qparser.QueryParser("text", self.ix.schema, group=qparser.OrGroup).parse(query)
            # Search using the query q and get the topk documents, highest-scoring first.
            results = searcher.search(q, limit=topk)
            # Each hit exposes the stored fields of a document; map its id back to the passage text.
            for hit in results:
                scores.append(hit.score)
                text.append(self.passages[int(hit['id'])])
        return text, scores
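
# Example usage (a minimal sketch; assumes the default data/index paths and that
# src/data/entities.txt and the wiki_scrape module are available; the query
# string is purely illustrative):
if __name__ == "__main__":
    ir = IR()
    passages, scores = ir.retrieve_documents("Who painted the Mona Lisa?", topk=3)
    for passage, score in zip(passages, scores):
        # Print the score and the first 100 characters of each retrieved passage.
        print(f"{score:.2f}\t{passage[:100]}...")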