File size: 5,994 Bytes
a9c2120 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import collections
import heapq
import math
import pickle
import sys
from numpy import inf
import gradio as gr
# Default BM25 hyperparameters; see BM25.__init__ for what each one controls.
PARAM_K1 = 1.5   # term-frequency saturation constant
PARAM_B = 0.75   # strength of document-length normalisation
# Negative infinity: with this default cutoff no term is dropped for having a low IDF.
IDF_CUTOFF = float("-inf")
# Built off https://github.com/Inspirateur/Fast-BM25
class BM25:
    """Fast implementation of the Okapi BM25 ranking function.

    The index is an inverted map from each token to the documents containing it,
    so scoring a query only touches documents that share at least one token with it.

    Attributes
    ----------
    k1 : float
        Term-frequency saturation constant.
    b : float
        Document-length normalisation strength.
    alpha : float
        IDF cutoff; tokens whose IDF is not strictly greater than this are dropped.
    corpus : list of list of str
        The tokenized corpus exactly as passed to the constructor.
    t2d : dict[str, dict[int, int]]
        ``{token: {doc_index: term_frequency}}`` for every retained token.
    idf : dict[str, float]
        Pre-computed IDF score for every retained token.
    doc_len : list of int
        Token count of each document, indexed like ``corpus``.
    avgdl : float
        Average document length in tokens (0 when no documents were indexed).
    average_idf : float
        Mean IDF over retained tokens (0 when none are retained or no index was built).
    """

    def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
        """
        Parameters
        ----------
        corpus : list of list of str
            Tokenized documents; document ``i`` is referred to by index ``i``.
            If falsy (e.g. empty or None) no index is built and queries match nothing.
        k1 : float
            Controls term-frequency saturation: once a term has appeared a few times in
            a document, further occurrences add progressively less to its score.
            Values in roughly 1.2 < k1 < 2 are commonly reported to work well, though the
            best value depends on the documents and queries.
        b : float
            Controls document-length normalisation. Larger values penalise documents
            that are longer than average more strongly; b = 0 disables length
            normalisation. Values in roughly 0.5 < b < 0.8 are commonly reported to
            work well, depending on the documents and queries.
        alpha : float
            IDF cutoff: tokens whose IDF score is not strictly greater than alpha are
            dropped from the index. Raising alpha shrinks the index (faster queries)
            at the cost of ignoring more tokens, which can hurt ranking quality.
        """
        self.k1 = k1
        self.b = b
        self.alpha = alpha
        self.corpus = corpus
        self.avgdl = 0
        self.t2d = {}
        self.idf = {}
        self.doc_len = []
        # Defined up front so the attribute exists even when no index is built.
        self.average_idf = 0
        if corpus:
            self._initialize(corpus)

    @property
    def corpus_size(self):
        """Number of indexed documents."""
        return len(self.doc_len)

    def _initialize(self, corpus, progress=gr.Progress()):
        """Build term frequencies, document lengths and IDF scores from ``corpus``.

        Parameters
        ----------
        corpus : iterable of list of str
            Tokenized documents; the i-th document gets index ``i``.
        progress : gr.Progress
            Gradio progress tracker; ``progress.tqdm`` wraps the corpus iteration
            to report indexing progress.
        """
        for doc_index, document in enumerate(
            progress.tqdm(corpus, desc="Preparing search index", unit="rows")
        ):
            self.doc_len.append(len(document))
            # Counter keeps first-occurrence order, so tokens enter t2d in the same
            # order a word-by-word loop would insert them.
            for word, freq in collections.Counter(document).items():
                self.t2d.setdefault(word, {})[doc_index] = freq

        # Guard against a truthy-but-empty corpus (e.g. an exhausted generator).
        self.avgdl = sum(self.doc_len) / len(self.doc_len) if self.doc_len else 0

        to_delete = []
        for word, docs in self.t2d.items():
            # BM25 IDF, log((N - n + 0.5) / (n + 0.5)); negative for tokens that
            # occur in more than half of the documents.
            idf = math.log(self.corpus_size - len(docs) + 0.5) - math.log(len(docs) + 0.5)
            # only store the idf score if it's above the threshold
            if idf > self.alpha:
                self.idf[word] = idf
            else:
                to_delete.append(word)
        print(f"Dropping {len(to_delete)} terms")
        for word in to_delete:
            del self.t2d[word]

        if not self.idf:
            print("Alpha value too high - all words removed from dataset.")
            self.average_idf = 0
        else:
            self.average_idf = sum(self.idf.values()) / len(self.idf)

        if self.average_idf < 0:
            print(
                f'Average inverse document frequency is less than zero. Your corpus of {self.corpus_size} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.',
                file=sys.stderr
            )

    def _get_scores(self, query):
        """Return ``{doc_index: BM25 score}`` for documents sharing a token with ``query``.

        Documents containing none of the query tokens are absent from the result.
        Query tokens that are not in the index are ignored; repeated query tokens
        contribute once per occurrence.
        """
        scores = collections.defaultdict(float)
        for token in query:
            postings = self.t2d.get(token)
            if postings is None:
                continue
            token_idf = self.idf[token]
            for index, freq in postings.items():
                # Length normalisation: a longer-than-average document gets a larger
                # denominator (for b > 0) and therefore a lower per-occurrence score.
                denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
                scores[index] += token_idf * freq * (self.k1 + 1) / (freq + denom_cst)
        return scores

    def get_top_n(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query.

        Parameters
        ----------
        query: list of str
            The tokenized query
        documents: list
            The documents to return from; must be aligned with the indexed corpus
            (same length, same order)
        n: int
            The number of documents to return

        Returns
        -------
        list
            Up to n documents, best match first. Documents sharing no token with the
            query are never returned.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = self._get_scores(query)
        return [documents[i] for i in heapq.nlargest(n, scores, key=scores.__getitem__)]

    def get_top_n_with_score(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query along with their scores.

        Parameters
        ----------
        query: list of str
            The tokenized query
        documents: list
            The documents to return from; must be aligned with the indexed corpus
            (same length, same order)
        n: int
            The number of documents to return

        Returns
        -------
        list of tuple
            Up to n ``(index, document, score)`` tuples, best match first, where
            ``index`` is the document's row index in ``documents``.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = self._get_scores(query)
        top_n_indices = heapq.nlargest(n, scores, key=scores.__getitem__)
        return [(i, documents[i], scores[i]) for i in top_n_indices]

    def extract_documents_and_scores(self, query, documents, n=5):
        """
        Extract the top n matches into separate parallel lists.

        Parameters
        ----------
        query: list of str
            The tokenized query
        documents: list
            The documents to return from
        n: int
            The number of documents to return

        Returns
        -------
        tuple of (list of int, list, list of float)
            Row indices, documents and scores of the top matches, best first.
            Three empty lists when no document matches the query.
        """
        results = self.get_top_n_with_score(query, documents, n)
        # Check explicitly instead of relying on zip(*[]) failing to unpack.
        if not results:
            print("No search results returned")
            return [], [], []
        indices, docs, scores = zip(*results)
        return list(indices), list(docs), list(scores)

    def save(self, filename):
        """Pickle this whole index (including ``corpus``) to ``f"{filename}.pkl"``."""
        with open(f"{filename}.pkl", "wb") as fsave:
            pickle.dump(self, fsave, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(filename):
        """Load and return the object pickled in ``f"{filename}.pkl"`` (see ``save``).

        Security: unpickling can execute arbitrary code. Only load files you created
        yourself or otherwise trust.
        """
        with open(f"{filename}.pkl", "rb") as fsave:
            return pickle.load(fsave)
|