# Provenance (from original hosting page): Sean-Case, "Initial commit" (a9c2120), 5.99 kB
import collections
import heapq
import math
import pickle
import sys
from numpy import inf
import gradio as gr
# Default BM25 hyper-parameters (see the class docstring for their meaning).
PARAM_K1 = 1.5
PARAM_B = 0.75
# Default IDF cutoff: -infinity keeps every term. math.inf is the same float
# as numpy's inf, and avoids a third-party dependency for a plain constant.
IDF_CUTOFF = -math.inf


# Built off https://github.com/Inspirateur/Fast-BM25
class BM25:
    """Fast Implementation of Best Matching 25 ranking function.

    Attributes
    ----------
    t2d : dict
        ``{token: {doc_index: frequency}}`` — term frequencies for each
        document in `corpus`.
    idf : dict
        ``{token: idf_score}`` — pre-computed IDF score for every retained term.
    doc_len : list of int
        List of document lengths (token counts).
    avgdl : float
        Average length of a document in `corpus`.
    """

    def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
        """
        Parameters
        ----------
        corpus : list of list of str
            Given corpus (each document already tokenized).
        k1 : float
            Constant used for influencing the term frequency saturation. After
            saturation is reached, additional presence for the term adds a
            significantly less additional score. Experiments suggest that
            1.2 < k1 < 2 yields reasonably good results, although the optimal
            value depends on factors such as the type of documents or queries.
        b : float
            Constant used for influencing the effects of different document
            lengths relative to average document length. When b is bigger,
            lengthier documents (compared to average) have more impact on the
            score. Experiments suggest 0.5 < b < 0.8 works well.
        alpha : float
            IDF cutoff; terms with a lower idf score than alpha will be
            dropped. A higher alpha lowers the accuracy of BM25 but increases
            performance.
        """
        self.k1 = k1
        self.b = b
        self.alpha = alpha
        self.corpus = corpus
        self.avgdl = 0
        self.t2d = {}
        self.idf = {}
        self.doc_len = []
        # An empty corpus is valid: the index simply stays empty.
        if corpus:
            self._initialize(corpus)

    @property
    def corpus_size(self):
        """Number of documents in the index."""
        return len(self.doc_len)

    def _initialize(self, corpus, progress=None):
        """Calculate term frequencies per document and corpus-wide IDF scores.

        Parameters
        ----------
        corpus : list of list of str
            Tokenized documents to index.
        progress : object, optional
            Progress tracker exposing a ``tqdm(iterable, ...)`` method.
            Defaults to a fresh ``gradio.Progress``; created lazily here so
            no gradio object is instantiated at class-definition time (the
            original evaluated ``gr.Progress()`` in the ``def`` line).
        """
        if progress is None:
            progress = gr.Progress()
        for doc_index, document in enumerate(
            progress.tqdm(corpus, desc="Preparing search index", unit="rows")
        ):
            self.doc_len.append(len(document))
            for word in document:
                counts = self.t2d.setdefault(word, {})
                counts[doc_index] = counts.get(doc_index, 0) + 1
        self.avgdl = sum(self.doc_len) / len(self.doc_len)

        to_delete = []
        for word, docs in self.t2d.items():
            # Okapi IDF: log((N - df + 0.5) / (df + 0.5)); can be negative
            # for terms appearing in more than half the documents.
            idf = math.log(self.corpus_size - len(docs) + 0.5) - math.log(len(docs) + 0.5)
            # only store the idf score if it's above the threshold
            if idf > self.alpha:
                self.idf[word] = idf
            else:
                to_delete.append(word)
        print(f"Dropping {len(to_delete)} terms")
        for word in to_delete:
            del self.t2d[word]

        if len(self.idf) == 0:
            print("Alpha value too high - all words removed from dataset.")
            self.average_idf = 0
        else:
            self.average_idf = sum(self.idf.values()) / len(self.idf)

        if self.average_idf < 0:
            print(
                f'Average inverse document frequency is less than zero. Your corpus of {self.corpus_size} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.',
                file=sys.stderr
            )

    def _score_query(self, query):
        """Return ``{doc_index: BM25 score}`` for every document matching *query*.

        Shared by :meth:`get_top_n` and :meth:`get_top_n_with_score`, which
        previously duplicated this loop.
        """
        scores = collections.defaultdict(float)
        for token in query:
            if token in self.t2d:
                idf = self.idf[token]
                for index, freq in self.t2d[token].items():
                    # Length-normalised BM25 term contribution.
                    denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
                    scores[index] += idf * freq * (self.k1 + 1) / (freq + denom_cst)
        return scores

    def get_top_n(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to return from; must align 1:1 with the indexed corpus.
        n : int
            The number of documents to return.

        Returns
        -------
        list
            The top n documents, best match first.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = self._score_query(query)
        return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]

    def get_top_n_with_score(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query along with their scores.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to return from; must align 1:1 with the indexed corpus.
        n : int
            The number of documents to return.

        Returns
        -------
        list of tuple
            The top n documents along with their scores and row indices in
            the format ``(index, document, score)``, best match first.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = self._score_query(query)
        top_n_indices = heapq.nlargest(n, scores.keys(), key=scores.__getitem__)
        return [(i, documents[i], scores[i]) for i in top_n_indices]

    def extract_documents_and_scores(self, query, documents, n=5):
        """
        Extract top n documents and their scores into separate lists.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to return from.
        n : int
            The number of documents to return.

        Returns
        -------
        tuple of (list, list, list)
            Row indices, documents, and scores of the top n matches; three
            empty lists when the query matched nothing.
        """
        results = self.get_top_n_with_score(query, documents, n)
        try:
            indices, docs, scores = zip(*results)
        except ValueError:
            # zip(*[]) yields nothing to unpack when there are no matches;
            # catch only that case rather than swallowing every exception.
            print("No search results returned")
            return [], [], []
        # Return three lists (the original left `docs` as a tuple).
        return list(indices), list(docs), list(scores)

    def save(self, filename):
        """Pickle this index to ``<filename>.pkl``.

        Bug fix: the original ignored `filename` and wrote a hard-coded path.
        """
        with open(f"{filename}.pkl", "wb") as fsave:
            pickle.dump(self, fsave, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(filename):
        """Load a pickled index from ``<filename>.pkl``.

        NOTE(security): ``pickle.load`` can execute arbitrary code — only
        load files produced by :meth:`save` from a trusted source.
        """
        with open(f"{filename}.pkl", "rb") as fsave:
            return pickle.load(fsave)