Kurt committed on
Commit
b7fb5a8
·
1 Parent(s): 2b4cfa7
Files changed (1) hide show
  1. app.py +333 -7
app.py CHANGED
@@ -20,6 +20,10 @@ from nlp4web_codebase.nlp4web_codebase.ir.data_loaders.sciq import load_sciq
20
  from nlp4web_codebase.nlp4web_codebase.ir.models import BaseRetriever
21
  from typing import Type
22
  from abc import abstractmethod
 
 
 
 
23
 
24
  LANGUAGE = "english"
25
  word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
@@ -308,6 +312,326 @@ class BM25Retriever(BaseInvertedIndexRetriever):
308
  bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
309
  bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  class Hit(TypedDict):
312
  cid: str
313
  score: float
@@ -317,21 +641,23 @@ demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
317
  return_type = List[Hit]
318
 
319
  ## YOUR_CODE_STARTS_HERE
320
- bm25_index = BM25Index.build_from_documents(
321
  documents=iter(sciq.corpus),
322
  ndocs=12160,
323
- show_progress_bar=True
 
 
324
  )
325
- bm25_index.save("output/bm25_index")
326
 
327
  def search(query: str) -> List[Hit]:
328
- bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
329
- result = bm25_retriever.retrieve(query)
330
 
331
  l : return_type = []
332
  for cid, score in result.items():
333
- docid = bm25_retriever.index.cid2docid[cid]
334
- text = bm25_retriever.index.doc_texts[docid]
335
 
336
  l.append(Hit(cid=cid, score=score, text=text))
337
 
 
20
  from nlp4web_codebase.nlp4web_codebase.ir.models import BaseRetriever
21
  from typing import Type
22
  from abc import abstractmethod
23
+ from nlp4web_codebase.nlp4web_codebase.ir.data_loaders import Split
24
+ import pytrec_eval
25
+ import numpy as np
26
+ from scipy.sparse._csc import csc_matrix
27
 
28
  LANGUAGE = "english"
29
  word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
 
312
  bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
313
  bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
314
 
315
def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Return the mean ``map_cut_10`` of ``rankings`` on the given split.

    Args:
        rankings: Maps each query id to a ``{collection id: score}`` ranking.
        split: Dataset split whose relevance judgements are loaded through
            ``sciq.get_qrels_dict``.

    Returns:
        The per-query ``map_cut_10`` values reported by pytrec_eval,
        averaged with ``np.mean`` and converted to a builtin float.
    """
    metric = "map_cut_10"
    # Load the judgements once and hand that same dict to the evaluator.
    qrels = sciq.get_qrels_dict(split)
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    qps = evaluator.evaluate(rankings)
    return float(np.mean([qp[metric] for qp in qps.values()]))
321
+
322
+ """Example of using the pre-requisite code:"""
323
+
324
# Loading dataset:
# Use the same package path as the other nlp4web_codebase imports in this file.
from nlp4web_codebase.nlp4web_codebase.ir.data_loaders.sciq import load_sciq
sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

# Building BM25 index and save:
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    # Derive the document count from the corpus instead of hard-coding it.
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")

# Loading index and use BM25 retriever to retrieve:
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?"))  # the ranking
340
+
341
# Parameter grids for the BM25 sweep: "X" holds the candidate values and "Y"
# is filled later with one score per candidate.
plots_b: Dict[str, List[float]] = dict(
    X=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    Y=[],
)
plots_k1: Dict[str, List[float]] = dict(
    X=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    Y=[],
)
349
+
350
## YOUR_CODE_STARTS_HERE
# Two steps should be involved:
# Step 1. Fix k1 value to the default one 0.9,
# go through all the candidate b values (0, 0.1, ..., 1.0),
# and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
# Step 2. Fix b to the best one in step 1. and do the same for k1.

# Hint (on using the pre-requisite code):
# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
# - One can build bm25_index with `BM25Index.build_from_documents`;
# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
# (dev queries can be obtained via sciq.get_split_queries(Split.dev))


def _build_bm25_retriever(k1: float, b: float):
    """Build a BM25 index with ``k1``/``b``, save it, and load a retriever on it.

    The index is written to "output/bm25_index" (overwriting the previous
    one) because ``BM25Retriever`` is constructed from an index directory.

    Returns:
        The ``(index, retriever)`` pair.
    """
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),
        show_progress_bar=True,
        k1=k1,
        b=b,
    )
    index.save("output/bm25_index")
    return index, BM25Retriever(index_dir="output/bm25_index")


def _rank_split(retriever, split) -> Dict[str, Dict[str, float]]:
    """Run ``retriever`` on every query of ``split``; return ``{query_id: ranking}``."""
    return {
        query.query_id: retriever.retrieve(query=query.text)
        for query in sciq.get_split_queries(split)
    }


# Step 1: sweep b while k1 stays at 0.9.
k1 = 0.9
b_list = []
for b in plots_b["X"]:
    bm25_index, bm25_retriever = _build_bm25_retriever(k1=k1, b=b)
    rankings = _rank_split(bm25_retriever, Split.dev)
    b_list.append(evaluate_map(rankings, split=Split.dev))

plots_b["Y"] = b_list

# Step 2: sweep k1 while b stays at the best value found in step 1.
b = plots_b["X"][np.argmax(plots_b["Y"])]
k1_list = []
for k1 in plots_k1["X"]:
    bm25_index, bm25_retriever = _build_bm25_retriever(k1=k1, b=b)
    rankings = _rank_split(bm25_retriever, Split.dev)
    k1_list.append(evaluate_map(rankings, split=Split.dev))

plots_k1["Y"] = k1_list


"""Let's check the effectiveness gain on test after this tuning on dev"""

# NOTE(review): baseline score is a hard-coded constant; it is not recomputed here.
default_map = 0.7849
best_b = plots_b["X"][np.argmax(plots_b["Y"])]
best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
bm25_index, bm25_retriever = _build_bm25_retriever(k1=best_k1, b=best_b)
rankings = _rank_split(bm25_retriever, Split.test)  # note this is now on test
optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
print(default_map, optimized_map)
428
+
429
+ """## TASK2.2: implement `CSCBM25Index` (4 points)
430
+
431
+ Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to the `InvertedIndex` we discussed in class; the main difference is that the posting lists are represented by a CSC sparse matrix.
432
+ """
433
+
434
@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists live in a single CSC sparse matrix.

    The matrix is addressed as ``[docid, term id]``: each column is the
    posting list of one vocabulary term.
    """

    posting_lists_matrix: csc_matrix  # rows: docid, columns: term id
    vocab: Dict[str, int]  # token -> term id
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``.

        The directory is created when it does not exist yet.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Return the index unpickled from ``<saved_dir>/index.pkl``.

        The object is returned exactly as unpickled; it is not re-validated
        against ``cls``.

        NOTE: ``pickle.load`` can execute arbitrary code, so only load index
        directories that come from a trusted source.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
455
+
456
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index that caches every term weight in a docid x term-id CSC matrix."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize``."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute term weights and caching

        Each posting of term ``tid`` in document ``docid`` becomes the matrix
        entry ``[docid, tid] = calc_regularized_tf(...) * calc_idf(...)``.

        Args:
            posting_lists: One posting list per term id; ``docid_postings``
                and ``tweight_postings`` are read in lockstep.
            total_docs: Collection size, used as ``N`` in the IDF.
            avgdl: Average document length.
            dfs: Document frequency per term id.
            dls: Document length per docid.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            A float32 ``csc_matrix`` with one column per posting list and
            one row per document.
        """
        ## YOUR_CODE_STARTS_HERE
        data: List[float] = []
        indices: List[int] = []
        # Column pointers start at 0; seeding it here keeps the pointer array
        # well formed even when there are no posting lists at all.
        indptr: List[int] = [0]

        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=total_docs)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                indices.append(docid)
                data.append(regularized_tf * idf)
            # End of this term's column.
            indptr.append(len(indices))

        # Fix the shape explicitly. When it is inferred from the data, trailing
        # documents without any posting lose their rows, and a later
        # [docid, tid] lookup for them would be out of range. `dls` is indexed
        # by docid above, so its length also bounds the largest docid.
        n_docs = max(total_docs, len(dls))
        posting_lists_matrix = csc_matrix(
            (data, indices, indptr), shape=(n_docs, len(posting_lists))
        ).astype(np.float32)
        return posting_lists_matrix
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return the BM25-saturated, length-normalized term frequency."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """Return ``log(1 + (N - df + 0.5) / (df + 0.5))``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count statistics over ``documents`` and assemble the index.

        Args:
            documents: Documents to index.
            store_raw: Forwarded to ``run_counting``.
            output_dir: Accepted for interface compatibility; this method
                does not write anything. Call ``save`` on the result.
            ndocs: Forwarded to ``run_counting``.
            show_progress_bar: Forwarded to ``run_counting``.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            The assembled index, with all term weights precomputed.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly:
        return cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
562
+
563
# Build the CSC-backed BM25 index with the tuned parameters and persist it.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    # Derive the document count from the corpus instead of hard-coding it.
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")
571
+
572
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever that ranks documents by summing cached weights of a CSC index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Index class whose ``from_saved`` is used to load the index."""
        pass

    def __init__(self, index_dir: str) -> None:
        """Load the index stored under ``index_dir``."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the cached weight of each query token in document ``cid``.

        Tokens missing from the vocabulary and tokens whose weight in the
        document is zero are left out.

        Raises:
            KeyError: If ``cid`` is not in the index.
        """
        ## YOUR_CODE_STARTS_HERE
        target_docid = self.index.cid2docid[cid]
        matrix = self.index.posting_lists_matrix
        term_weights: Dict[str, float] = {}
        for tok in CSCBM25Index.tokenize(query):
            tid = self.index.vocab.get(tok)
            if tid is None:  # 0 is a valid term id, so test against None
                continue
            weight = matrix[target_docid, tid]
            if weight == 0:
                continue
            # Matrix entries are numpy scalars; return builtin floats as annotated.
            term_weights[tok] = float(weight)
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the term weights of ``query`` in document ``cid``."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` best-scoring documents for ``query``.

        Returns:
            ``{collection id: score}`` ordered by descending score, with
            scores as builtin floats.
        """
        ## YOUR_CODE_STARTS_HERE
        matrix = self.index.posting_lists_matrix
        docid2score: Dict[int, float] = {}
        for tok in CSCBM25Index.tokenize(query):
            tid = self.index.vocab.get(tok)
            if tid is None:  # 0 is a valid term id, so test against None
                continue
            posting_list = matrix.getcol(tid)
            for docid, tweight in zip(posting_list.indices, posting_list.data):
                docid2score[docid] = docid2score.get(docid, 0) + tweight

        top_hits = sorted(
            docid2score.items(), key=lambda pair: pair[1], reverse=True
        )[:topk]
        # Accumulated scores are numpy scalars, which e.g. json cannot
        # serialize; convert them to builtin floats as annotated.
        return {
            self.index.collection_ids[docid]: float(score)
            for docid, score in top_hits
        }
        ## YOUR_CODE_ENDS_HERE
626
+
627
+
628
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever whose index is a ``CSCBM25Index``."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        """Return ``CSCBM25Index``, the class used to load the saved index."""
        return CSCBM25Index
633
+
634
+
635
  class Hit(TypedDict):
636
  cid: str
637
  score: float
 
641
  return_type = List[Hit]
642
 
643
  ## YOUR_CODE_STARTS_HERE
644
# The CSC BM25 index for (best_k1, best_b) has already been built from this
# corpus and saved to "output/csc_bm25_index" earlier in this module, right
# after the CSCBM25Index class. Building it again here would repeat the same
# work and overwrite the file with identical content, so the saved index is
# reused; `search` below loads it from that directory.
652
 
653
  def search(query: str) -> List[Hit]:
654
+ csc_bm25_index = CSCBM25Retriever(index_dir="output/csc_bm25_index")
655
+ result = csc_bm25_index.retrieve(query)
656
 
657
  l : return_type = []
658
  for cid, score in result.items():
659
+ docid = csc_bm25_index.index.cid2docid[cid]
660
+ text = csc_bm25_index.index.doc_texts[docid]
661
 
662
  l.append(Hit(cid=cid, score=score, text=text))
663