geekyrakshit commited on
Commit
7f98acf
·
1 Parent(s): 88a5bcf

update: BM25sRetriever

Browse files
medrag_multi_modal/retrieval/bm25s_retrieval.py CHANGED
@@ -29,9 +29,9 @@ class BM25sRetriever(weave.Model):
29
  super().__init__(language=language, use_stemmer=use_stemmer)
30
  self._retriever = retriever or bm25s.BM25()
31
 
32
- def index(self, corpus_dataset_name: str, index_name: Optional[str] = None):
33
- corpus_dataset = weave.ref(corpus_dataset_name).get().rows
34
- corpus = [row["text"] for row in corpus_dataset]
35
  corpus_tokens = bm25s.tokenize(
36
  corpus,
37
  stopwords=LANGUAGE_DICT[self.language],
@@ -40,7 +40,7 @@ class BM25sRetriever(weave.Model):
40
  self._retriever.index(corpus_tokens)
41
  if index_name:
42
  self._retriever.save(
43
- index_name, corpus=[dict(row) for row in corpus_dataset]
44
  )
45
  if wandb.run:
46
  artifact = wandb.Artifact(
@@ -81,8 +81,8 @@ class BM25sRetriever(weave.Model):
81
  stopwords=LANGUAGE_DICT[self.language],
82
  stemmer=Stemmer(self.language) if self.use_stemmer else None,
83
  )
84
- results, scores = self._retriever.retrieve(query_tokens, k=top_k)
85
  return {
86
- "results": results,
87
- "scores": scores,
88
  }
 
29
  super().__init__(language=language, use_stemmer=use_stemmer)
30
  self._retriever = retriever or bm25s.BM25()
31
 
32
+ def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
33
+ chunk_dataset = weave.ref(chunk_dataset_name).get().rows
34
+ corpus = [row["text"] for row in chunk_dataset]
35
  corpus_tokens = bm25s.tokenize(
36
  corpus,
37
  stopwords=LANGUAGE_DICT[self.language],
 
40
  self._retriever.index(corpus_tokens)
41
  if index_name:
42
  self._retriever.save(
43
+ index_name, corpus=[dict(row) for row in chunk_dataset]
44
  )
45
  if wandb.run:
46
  artifact = wandb.Artifact(
 
81
  stopwords=LANGUAGE_DICT[self.language],
82
  stemmer=Stemmer(self.language) if self.use_stemmer else None,
83
  )
84
+ results = self._retriever.retrieve(query_tokens, k=top_k)
85
  return {
86
+ "results": results.documents,
87
+ "scores": results.scores,
88
  }