moritz648 committed
Commit 6ee748f · 1 Parent(s): 905edb5
Files changed (1)
  1. app.py +303 -1
app.py CHANGED
@@ -5,9 +5,40 @@ from datasets import load_dataset
import joblib
from dataclasses import dataclass
from enum import Enum
- from typing import Dict, List
+ from typing import Dict, List, Type
from dataclasses import dataclass
from typing import Optional
+ import math
+ import os
+ import pickle
+ import re
+ from abc import ABC, abstractmethod
+ from collections import Counter
+ from typing import Any, Callable, Iterable, TypeVar
+
+ import nltk
+ import tqdm
+
+ nltk.download("stopwords", quiet=True)
+ from nltk.corpus import stopwords as nltk_stopwords
+
+ class BaseRetriever(ABC):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[Any]:
+         pass
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         raise NotImplementedError
+
+     @abstractmethod
+     def score(self, query: str, cid: str) -> float:
+         pass
+
+     @abstractmethod
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         pass


@dataclass
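
As a quick illustration of the interface this hunk introduces: a concrete
retriever only has to supply an index class plus score/retrieve. The sketch
below is hypothetical (not part of this commit) and scores documents by raw
term overlap over an in-memory cid -> text dict:

    # Hypothetical toy subclass, just to show the BaseRetriever contract.
    class ToyOverlapRetriever(BaseRetriever):
        def __init__(self, corpus: Dict[str, str]) -> None:
            self.corpus = corpus  # cid -> raw document text

        @property
        def index_class(self) -> Type[Any]:
            return dict  # no real index structure in this toy

        def score(self, query: str, cid: str) -> float:
            doc_toks = set(self.corpus[cid].lower().split())
            return float(sum(t in doc_toks for t in query.lower().split()))

        def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
            scores = {cid: self.score(query, cid) for cid in self.corpus}
            return dict(
                sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:topk]
            )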
 
@@ -115,6 +151,64 @@ def load_sciq(verbose: bool = False) -> IRDataset:

    # Assembly and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
+
+ class BaseInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type["InvertedIndex"]:  # quoted: InvertedIndex is defined later in this file
+         pass
+
+     def __init__(self, index_dir: str) -> None:
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 if docid == target_docid:
+                     term_weights[tok] = tweight
+                     break
+         return term_weights
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())
+
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         docid2score: Dict[int, float] = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             posting_list = self.index.posting_lists[tid]
+             for docid, tweight in zip(
+                 posting_list.docid_postings, posting_list.tweight_postings
+             ):
+                 docid2score.setdefault(docid, 0)
+                 docid2score[docid] += tweight
+         docid2score = dict(
+             sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
+         )
+         return {
+             self.index.collection_ids[docid]: score
+             for docid, score in docid2score.items()
+         }
+
+
+ class BM25Retriever(BaseInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type["BM25Index"]:  # quoted: BM25Index is added in the hunk below
+         return BM25Index


if __name__ == "__main__":
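
A minimal usage sketch for the retriever classes added above (hypothetical
index path and query; it assumes a BM25Index was previously built and saved
to that directory):

    retriever = BM25Retriever(index_dir="output/bm25_index")  # loads index.pkl
    hits = retriever.retrieve("What is the pH of pure water?", topk=3)
    # hits maps collection ids to accumulated BM25 scores, highest first,
    # e.g. {"train-doc-123": 7.81, ...}
    print(hits)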
 
@@ -142,6 +236,211 @@ if __name__ == "__main__":
    # "|qrels-test|": 876
    # }

+ LANGUAGE = "english"
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+ def word_splitting(text: str) -> List[str]:
+     return word_splitter(text.lower())
+
+ def lemmatization(words: List[str]) -> List[str]:
+     return words  # we ignore lemmatization here for simplicity
+
+ def simple_tokenize(text: str) -> List[str]:
+     words = word_splitting(text)
+     tokenized = list(filter(lambda w: w not in stopwords, words))
+     tokenized = lemmatization(tokenized)
+     return tokenized
+
+ T = TypeVar("T", bound="InvertedIndex")
+
+ @dataclass
+ class PostingList:
+     term: str  # the term
+     docid_postings: List[int]  # docid_postings[i]: docid (int) of the i-th posting
+     tweight_postings: List[float]  # tweight_postings[i]: term weight (float) of the i-th posting
+
+
+ @dataclass
+ class InvertedIndex:
+     posting_lists: List[PostingList]  # tid -> posting list
+     vocab: Dict[str, int]  # term -> tid
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+
+ # The output of the counting function:
+ @dataclass
+ class Counting:
+     posting_lists: List[PostingList]
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     dfs: List[int]  # tid -> df
+     dls: List[int]  # docid -> doc length
+     avgdl: float
+     nterms: int
+     doc_texts: Optional[List[str]] = None
+
+ def run_counting(
+     documents: Iterable[Document],
+     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+     store_raw: bool = True,  # store the document text in doc_texts
+     ndocs: Optional[int] = None,
+     show_progress_bar: bool = True,
+ ) -> Counting:
+     """Counting TFs, DFs, doc_lengths, etc."""
+     posting_lists: List[PostingList] = []
+     vocab: Dict[str, int] = {}
+     cid2docid: Dict[str, int] = {}
+     collection_ids: List[str] = []
+     dfs: List[int] = []  # tid -> df
+     dls: List[int] = []  # docid -> doc length
+     nterms: int = 0
+     doc_texts: Optional[List[str]] = [] if store_raw else None
+     for doc in tqdm.tqdm(
+         documents,
+         desc="Counting",
+         total=ndocs,
+         disable=not show_progress_bar,
+     ):
+         if doc.collection_id in cid2docid:
+             continue
+         collection_ids.append(doc.collection_id)
+         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+         toks = tokenize_fn(doc.text)
+         tok2tf = Counter(toks)
+         dls.append(sum(tok2tf.values()))
+         for tok, tf in tok2tf.items():
+             nterms += tf
+             tid = vocab.get(tok, None)
+             if tid is None:
+                 posting_lists.append(
+                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                 )
+                 tid = vocab.setdefault(tok, len(vocab))
+             posting_lists[tid].docid_postings.append(docid)
+             posting_lists[tid].tweight_postings.append(tf)
+             if tid < len(dfs):
+                 dfs[tid] += 1
+             else:
+                 dfs.append(1)  # first document containing this term counts towards its df
+         if store_raw:
+             doc_texts.append(doc.text)
+     return Counting(
+         posting_lists=posting_lists,
+         vocab=vocab,
+         cid2docid=cid2docid,
+         collection_ids=collection_ids,
+         dfs=dfs,
+         dls=dls,
+         avgdl=sum(dls) / len(dls),
+         nterms=nterms,
+         doc_texts=doc_texts,
+     )
+
+
+ @dataclass
+ class BM25Index(InvertedIndex):
+
+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
+     @staticmethod
+     def cache_term_weights(
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
+     ) -> None:
+         """Compute BM25 term weights and cache them in the posting lists."""
+
+         N = total_docs
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
+             idf = BM25Index.calc_idf(df=dfs[tid], N=N)
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
+                 tf = posting_list.tweight_postings[i]
+                 dl = dls[docid]
+                 regularized_tf = BM25Index.calc_regularized_tf(
+                     tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 )
+                 posting_list.tweight_postings[i] = regularized_tf * idf
+
+     @staticmethod
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
+         return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+     @staticmethod
+     def calc_idf(df: int, N: int) -> float:
+         return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
+     @classmethod
+     def build_from_documents(
+         cls: Type["BM25Index"],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> "BM25Index":
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=BM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and cache them:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         BM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assembly and save:
+         index = BM25Index(
+             posting_lists=posting_lists,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         if output_dir is not None:
+             index.save(output_dir)
+         return index
+

class Hit(TypedDict):
    cid: str
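
To tie the pieces of this commit together, a minimal end-to-end sketch
(hypothetical paths and query; it assumes the Document dataclass defined
earlier in app.py with collection_id and text fields, and that load_sciq
returns its corpus as a list of such documents):

    dataset = load_sciq()
    bm25_index = BM25Index.build_from_documents(
        documents=iter(dataset.corpus),
        ndocs=len(dataset.corpus),
        output_dir="output/bm25_index",  # index.pkl is written here
    )
    retriever = BM25Retriever(index_dir="output/bm25_index")
    # Each query term contributes idf * regularized_tf to every document
    # in its posting list; retrieve sums these and keeps the top-k.
    print(retriever.retrieve("What is the pH of pure water?"))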