Spaces:
Sleeping
Sleeping
Kurt
committed on
Commit
·
b7fb5a8
1
Parent(s):
2b4cfa7
cool2
Browse files
app.py
CHANGED
@@ -20,6 +20,10 @@ from nlp4web_codebase.nlp4web_codebase.ir.data_loaders.sciq import load_sciq
|
|
20 |
from nlp4web_codebase.nlp4web_codebase.ir.models import BaseRetriever
|
21 |
from typing import Type
|
22 |
from abc import abstractmethod
|
|
|
|
|
|
|
|
|
23 |
|
24 |
LANGUAGE = "english"
|
25 |
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
|
@@ -308,6 +312,326 @@ class BM25Retriever(BaseInvertedIndexRetriever):
|
|
308 |
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
|
309 |
bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
|
310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
class Hit(TypedDict):
|
312 |
cid: str
|
313 |
score: float
|
@@ -317,21 +641,23 @@ demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
|
|
317 |
return_type = List[Hit]
|
318 |
|
319 |
## YOUR_CODE_STARTS_HERE
|
320 |
-
|
321 |
documents=iter(sciq.corpus),
|
322 |
ndocs=12160,
|
323 |
-
show_progress_bar=True
|
|
|
|
|
324 |
)
|
325 |
-
|
326 |
|
327 |
def search(query: str) -> List[Hit]:
|
328 |
-
|
329 |
-
result =
|
330 |
|
331 |
l : return_type = []
|
332 |
for cid, score in result.items():
|
333 |
-
docid =
|
334 |
-
text =
|
335 |
|
336 |
l.append(Hit(cid=cid, score=score, text=text))
|
337 |
|
|
|
20 |
from nlp4web_codebase.nlp4web_codebase.ir.models import BaseRetriever
|
21 |
from typing import Type
|
22 |
from abc import abstractmethod
|
23 |
+
from nlp4web_codebase.nlp4web_codebase.ir.data_loaders import Split
|
24 |
+
import pytrec_eval
|
25 |
+
import numpy as np
|
26 |
+
from scipy.sparse._csc import csc_matrix
|
27 |
|
28 |
# Language identifier for this module.
# NOTE(review): presumably consumed by the stop-word / stemming setup elsewhere in this file — confirm.
LANGUAGE = "english"

# Word splitter: returns every run of two or more word characters in a string.
_WORD_PATTERN = re.compile(r"(?u)\b\w\w+\b")
word_splitter = _WORD_PATTERN.findall
|
|
|
312 |
# Instantiate the BM25 retriever on the saved index directory and issue one
# sample query; the returned ranking is not stored.
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
bm25_retriever.retrieve(
    query="What type of diseases occur when the immune system attacks normal body cells?"
)
|
314 |
|
315 |
+
def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Return the mean ``map_cut_10`` of ``rankings`` on one SciQ split.

    Args:
        rankings: Mapping of query id -> {collection id -> retrieval score}.
        split: Dataset split whose relevance judgements are used
            (defaults to ``Split.dev``).

    Returns:
        The per-query ``map_cut_10`` values averaged over all evaluated queries.
    """
    metric = "map_cut_10"
    # Fetch the relevance judgements once and reuse them for the evaluator.
    qrels = sciq.get_qrels_dict(split)
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    qps = evaluator.evaluate(rankings)
    return float(np.mean([qp[metric] for qp in qps.values()]))
|
321 |
+
|
322 |
+
"""Example of using the pre-requisite code:"""
|
323 |
+
|
324 |
+
# Loading dataset:
# NOTE(review): the imports at the top of this file use the
# `nlp4web_codebase.nlp4web_codebase.ir...` package path; this one uses
# `nlp4web_codebase.ir...` — confirm that both resolve in the deployment.
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

# Building BM25 index and save:
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # derived from the corpus instead of a hard-coded count
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")

# Loading index and use BM25 retriever to retrieve:
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?"))  # the ranking
|
340 |
+
|
341 |
+
# Candidate hyper-parameter values used as the X axis of both sweeps.
_CANDIDATE_VALUES = (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)

# "X" holds the candidate values; "Y" starts empty and is filled in by the
# tuning code. Each dict gets its own list objects.
plots_b: Dict[str, List[float]] = dict(X=list(_CANDIDATE_VALUES), Y=[])
plots_k1: Dict[str, List[float]] = dict(X=list(_CANDIDATE_VALUES), Y=[])
|
349 |
+
|
350 |
+
## YOUR_CODE_STARTS_HERE
# Two steps should be involved:
# Step 1. Fix k1 value to the default one 0.9,
# go through all the candidate b values (0, 0.1, ..., 1.0),
# and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
# Step 2. Fix b to the best one in step 1. and do the same for k1.

# Hint (on using the pre-requisite code):
# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
# - One can build bm25_index with `BM25Index.build_from_documents`;
# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
# (dev queries can be obtained via sciq.get_split_queries(Split.dev))


def _dev_map_for(k1: float, b: float) -> float:
    """Build and save a BM25 index with (k1, b), then return its MAP on the dev queries."""
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),  # derived from the corpus instead of a hard-coded count
        show_progress_bar=True,
        k1=k1,
        b=b,
    )
    # Each configuration overwrites the same directory, so only the most
    # recently evaluated index remains on disk.
    index.save("output/bm25_index")
    retriever = BM25Retriever(index_dir="output/bm25_index")
    dev_rankings = {
        query.query_id: retriever.retrieve(query=query.text)
        for query in sciq.get_split_queries(Split.dev)
    }
    return evaluate_map(dev_rankings, split=Split.dev)


# Step 1: sweep b while k1 stays at its default value.
k1 = 0.9
b_list = []
for b in plots_b["X"]:
    b_list.append(_dev_map_for(k1=k1, b=b))

plots_b["Y"] = b_list

# Step 2: fix b to the best value from step 1 and sweep k1.
b = plots_b["X"][np.argmax(plots_b["Y"])]
k1_list = []
for k1 in plots_k1["X"]:
    k1_list.append(_dev_map_for(k1=k1, b=b))

plots_k1["Y"] = k1_list
|
405 |
+
|
406 |
+
|
407 |
+
|
408 |
+
"""Let's check the effectiveness gain on test after this tuning on dev"""
|
409 |
+
|
410 |
+
# NOTE(review): hard-coded baseline MAP, presumably measured on the test split
# with the default k1/b — confirm where this number comes from.
default_map = 0.7849

# Pick the candidates with the highest recorded dev performance.
best_b = plots_b["X"][np.argmax(plots_b["Y"])]
best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]

# Rebuild and save the index with the tuned parameters.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # derived from the corpus instead of a hard-coded count
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")

# note this is now on test
rankings = {
    query.query_id: bm25_retriever.retrieve(query=query.text)
    for query in sciq.get_split_queries(Split.test)
}
optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
print(default_map, optimized_map)
|
428 |
+
|
429 |
+
"""## TASK2.2: implement `CSCBM25Index` (4 points)
|
430 |
+
|
431 |
+
Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to the `InvertedIndex` we discussed in class; the main difference is that the posting lists are represented by a CSC sparse matrix.
|
432 |
+
"""
|
433 |
+
|
434 |
+
@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists are stored in one CSC sparse matrix.

    The matrix is addressed as ``[docid, tid]``, so column ``tid`` holds the
    posting list (the per-document weights) of the term with id ``tid``.
    """

    posting_lists_matrix: csc_matrix  # [docid, tid] -> weight; column tid is that term's posting list
    vocab: Dict[str, int]  # token -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load the index pickled at ``<saved_dir>/index.pkl``.

        The unpickled object is returned as-is. Unpickling executes arbitrary
        code, so only load index directories that come from a trusted source.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
|
455 |
+
|
456 |
+
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """CSC-matrix inverted index whose cached term weights are BM25 scores."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize``."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute the BM25 weight of every posting and cache it in a CSC matrix.

        Args:
            posting_lists: One posting list per term id; each provides the
                parallel sequences ``docid_postings`` and ``tweight_postings``
                (the latter is used as the term frequency).
            total_docs: Number of documents (``N`` in the IDF formula and the
                number of matrix rows).
            avgdl: Average document length.
            dfs: Document frequency per term id.
            dls: Document length per docid.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.

        Returns:
            A float32 CSC matrix of shape ``(total_docs, len(posting_lists))``
            whose entry ``[docid, tid]`` is ``regularized_tf * idf``.
        """
        ## YOUR_CODE_STARTS_HERE
        data: List[float] = []
        indices: List[int] = []
        # CSC column pointers: column tid spans data[indptr[tid]:indptr[tid + 1]].
        indptr: List[int] = [0]

        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=total_docs)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                indices.append(docid)
                data.append(regularized_tf * idf)
            indptr.append(len(indices))

        # Pass the shape explicitly: when it is inferred, the row count stops
        # at the largest docid that has a posting, so trailing documents
        # without any indexed term would fall outside the matrix.
        return csc_matrix(
            (data, indices, indptr),
            shape=(total_docs, len(posting_lists)),
            dtype=np.float32,
        )
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return ``tf / (tf + k1 * (1 - b + b * dl / avgdl))``."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        """Return ``log(1 + (N - df + 0.5) / (df + 0.5))``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count the corpus statistics and assemble a BM25 index from them.

        Args:
            documents: Documents to index.
            store_raw: Forwarded to ``run_counting``.
            output_dir: Accepted but not used by this method.
                NOTE(review): nothing is written here; callers save the index
                themselves via ``save`` — confirm whether saving was intended.
            ndocs: Forwarded to ``run_counting``.
            show_progress_bar: Forwarded to ``run_counting``.
            k1: BM25 ``k1`` parameter.
            b: BM25 ``b`` parameter.

        Returns:
            The assembled index (an instance of ``cls``).
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly:
        return cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
|
562 |
+
|
563 |
+
# Build the CSC BM25 index with the tuned parameters and save it.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # derived from the corpus instead of a hard-coded count
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")
|
571 |
+
|
572 |
+
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from a saved CSC-matrix inverted index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Index class whose ``from_saved`` is used to load the index."""
        pass

    def __init__(self, index_dir: str) -> None:
        """Load the index stored under ``index_dir``."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the non-zero cached weight of each query token in document ``cid``.

        Args:
            query: Query text; tokenized with ``CSCBM25Index.tokenize``.
            cid: Collection id of the document to inspect.

        Returns:
            Mapping token -> weight for the query tokens that are in the
            vocabulary and have a non-zero weight in the document.

        Raises:
            KeyError: If ``cid`` is not in the index.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = CSCBM25Index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights: Dict[str, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            # Matrix lookups yield NumPy scalars; convert to a built-in float
            # so the result matches the annotated Dict[str, float].
            weight = float(self.index.posting_lists_matrix[target_docid, tid])
            if weight == 0:
                continue
            term_weights[tok] = weight
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the term weights of ``query`` in document ``cid``."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` highest-scoring documents for ``query``.

        Each query token (repeats included) adds the weights of its posting
        list to the scores of the documents it occurs in.

        Args:
            query: Query text; tokenized with ``CSCBM25Index.tokenize``.
            topk: Maximum number of documents to return.

        Returns:
            Mapping collection id -> score, ordered by descending score.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = CSCBM25Index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists_matrix.getcol(tid)
            # .tolist() yields built-in int/float values, so neither the keys
            # nor the scores leak NumPy scalar types to callers.
            for docid, tweight in zip(
                posting_list.indices.tolist(), posting_list.data.tolist()
            ):
                docid2score[docid] = docid2score.get(docid, 0.0) + tweight

        top_hits = sorted(
            docid2score.items(), key=lambda pair: pair[1], reverse=True
        )[:topk]
        return {self.index.collection_ids[docid]: score for docid, score in top_hits}
        ## YOUR_CODE_ENDS_HERE
|
626 |
+
|
627 |
+
|
628 |
+
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever whose index is loaded through ``CSCBM25Index``."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        """Return the index class used by this retriever."""
        index_cls = CSCBM25Index
        return index_cls
|
633 |
+
|
634 |
+
|
635 |
class Hit(TypedDict):
|
636 |
cid: str
|
637 |
score: float
|
|
|
641 |
return_type = List[Hit]
|
642 |
|
643 |
## YOUR_CODE_STARTS_HERE
|
644 |
+
# Build the CSC BM25 index with the tuned parameters and save it.
# NOTE(review): an identical build + save to "output/csc_bm25_index" already
# happens earlier in this file; this second build looks redundant — confirm.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # derived from the corpus instead of a hard-coded count
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")
|
652 |
|
653 |
def search(query: str) -> List[Hit]:
|
654 |
+
csc_bm25_index = CSCBM25Retriever(index_dir="output/csc_bm25_index")
|
655 |
+
result = csc_bm25_index.retrieve(query)
|
656 |
|
657 |
l : return_type = []
|
658 |
for cid, score in result.items():
|
659 |
+
docid = csc_bm25_index.index.cid2docid[cid]
|
660 |
+
text = csc_bm25_index.index.doc_texts[docid]
|
661 |
|
662 |
l.append(Hit(cid=cid, score=score, text=text))
|
663 |
|