import collections
import heapq
import math
import pickle
import sys

import gradio as gr

PARAM_K1 = 1.5
PARAM_B = 0.75
IDF_CUTOFF = -math.inf


class BM25:
    """Fast implementation of the Best Matching 25 (BM25) ranking function.

    Attributes
    ----------
    t2d : dict of {str: dict of {int: int}}
        Per-term document frequencies: maps each token to the documents of
        `corpus` containing it and the token's frequency in each of them.
    idf : dict of {str: float}
        Precomputed IDF score for every term.
    doc_len : list of int
        List of document lengths.
    avgdl : float
        Average document length in `corpus`.
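
    Examples
    --------
    Illustrative sketch only; the tokenized corpus and query below are made up::

        corpus = [
            ["hello", "world"],
            ["machine", "learning", "is", "fun"],
            ["hello", "machine"],
        ]
        index = BM25(corpus)
        index.get_top_n(["hello"], ["doc 1", "doc 2", "doc 3"], n=2)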
    """

    def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
        """
        Parameters
        ----------
        corpus : list of list of str
            Given corpus.
        k1 : float
            Constant controlling term-frequency saturation. Once saturation is reached,
            additional occurrences of a term add significantly less to the score. According to
            [1]_, experiments suggest that 1.2 < k1 < 2 yields reasonably good results, although
            the optimal value depends on factors such as the type of documents or queries.
        b : float
            Constant controlling how much document length, relative to the average document
            length, affects the score. The larger b is, the stronger the length normalization,
            i.e. the more longer-than-average documents are penalized. According to [1]_,
            experiments suggest that 0.5 < b < 0.8 yields reasonably good results, although the
            optimal value depends on factors such as the type of documents or queries.
        alpha : float
            IDF cutoff: terms with an IDF score lower than alpha are dropped from the index.
            A higher alpha lowers the accuracy of BM25 but improves performance.
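
        Notes
        -----
        These parameters feed the per-term score accumulated in `get_top_n`: for a query term t
        occurring f(t, D) times in a document D of length |D|,

            score += idf(t) * f(t, D) * (k1 + 1) / (f(t, D) + k1 * (1 - b + b * |D| / avgdl))

        with idf(t) = log((N - n_t + 0.5) / (n_t + 0.5)), where N is the corpus size and n_t is
        the number of documents containing t (see `_initialize`).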
        """
        self.k1 = k1
        self.b = b
        self.alpha = alpha
        self.corpus = corpus

        self.avgdl = 0
        self.t2d = {}
        self.idf = {}
        self.doc_len = []
        if corpus:
            self._initialize(corpus)

    @property
    def corpus_size(self):
        return len(self.doc_len)

    def _initialize(self, corpus, progress=gr.Progress()):
        """Calculates frequencies of terms in documents and in the corpus, and precomputes inverse document frequencies."""
        for i, document in enumerate(progress.tqdm(corpus, desc="Preparing search index", unit="rows")):
            self.doc_len.append(len(document))

            for word in document:
                if word not in self.t2d:
                    self.t2d[word] = {}
                if i not in self.t2d[word]:
                    self.t2d[word][i] = 0
                self.t2d[word][i] += 1

        self.avgdl = sum(self.doc_len) / len(self.doc_len)
        to_delete = []
        for word, docs in self.t2d.items():
            # Okapi BM25 IDF: log((N - n + 0.5) / (n + 0.5)); negative for terms that
            # appear in more than half of the documents.
            idf = math.log(self.corpus_size - len(docs) + 0.5) - math.log(len(docs) + 0.5)

            if idf > self.alpha:
                self.idf[word] = idf
            else:
                to_delete.append(word)
        print(f"Dropping {len(to_delete)} terms")
        for word in to_delete:
            del self.t2d[word]

        if len(self.idf) == 0:
            print("Alpha value too high - all words removed from dataset.")
            self.average_idf = 0
        else:
            self.average_idf = sum(self.idf.values()) / len(self.idf)

        if self.average_idf < 0:
            print(
                f'Average inverse document frequency is less than zero. Your corpus of {self.corpus_size} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.',
                file=sys.stderr
            )

    def get_top_n(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to return from.
        n : int
            The number of documents to return.

        Returns
        -------
        list
            The top n documents.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = collections.defaultdict(float)
        for token in query:
            if token in self.t2d:
                for index, freq in self.t2d[token].items():
                    # BM25 term score: idf * freq * (k1 + 1) / (freq + k1 * (1 - b + b * doc_len / avgdl))
                    denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
                    scores[index] += self.idf[token] * freq * (self.k1 + 1) / (freq + denom_cst)

        return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]

    def get_top_n_with_score(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query along with their scores.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to return from.
        n : int
            The number of documents to return.

        Returns
        -------
        list of tuple
            The top n documents as (index, document, score) tuples, where index is the
            row index of the document in the corpus.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = collections.defaultdict(float)
        for token in query:
            if token in self.t2d:
                for index, freq in self.t2d[token].items():
                    denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
                    scores[index] += self.idf[token] * freq * (self.k1 + 1) / (freq + denom_cst)

        top_n_indices = heapq.nlargest(n, scores.keys(), key=scores.__getitem__)
        return [(i, documents[i], scores[i]) for i in top_n_indices]

    def extract_documents_and_scores(self, query, documents, n=5):
        """
        Extract the top n document indices, documents and scores into separate lists.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to return from.
        n : int
            The number of documents to return.

        Returns
        -------
        tuple of (list, list, list)
            The row indices, the top n documents and their scores, in that order.
            Three empty lists are returned when there are no results.
        """
        results = self.get_top_n_with_score(query, documents, n)
        try:
            indices, docs, scores = zip(*results)
        except ValueError:
            # zip(*[]) yields nothing to unpack when there are no results.
            print("No search results returned")
            return [], [], []
        return list(indices), list(docs), list(scores)

    def save(self, filename):
        """Pickle the index to `filename`.pkl."""
        with open(f"{filename}.pkl", "wb") as fsave:
            pickle.dump(self, fsave, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(filename):
        """Load a pickled index from `filename`.pkl."""
        with open(f"{filename}.pkl", "rb") as fload:
            return pickle.load(fload)
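

if __name__ == "__main__":
    # Minimal usage sketch. The corpus below is made up; real callers would tokenize
    # their documents with whatever tokenizer the rest of the application uses.
    raw_docs = [
        "hello world",
        "machine learning is fun",
        "hello machine learning world",
        "information retrieval with bm25",
        "ranking functions for search",
    ]
    tokenized = [doc.lower().split() for doc in raw_docs]

    index = BM25(tokenized)
    print(index.get_top_n(["machine", "learning"], raw_docs, n=2))
    print(index.get_top_n_with_score(["hello"], raw_docs, n=2))
    print(index.extract_documents_and_scores(["fun"], raw_docs, n=1))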