Spaces:
Sleeping
Sleeping
Commit
·
7f98acf
1
Parent(s):
88a5bcf
update: BM25sRetriever
Browse files
medrag_multi_modal/retrieval/bm25s_retrieval.py
CHANGED
@@ -29,9 +29,9 @@ class BM25sRetriever(weave.Model):
|
|
29 |
super().__init__(language=language, use_stemmer=use_stemmer)
|
30 |
self._retriever = retriever or bm25s.BM25()
|
31 |
|
32 |
-
def index(self,
|
33 |
-
|
34 |
-
corpus = [row["text"] for row in
|
35 |
corpus_tokens = bm25s.tokenize(
|
36 |
corpus,
|
37 |
stopwords=LANGUAGE_DICT[self.language],
|
@@ -40,7 +40,7 @@ class BM25sRetriever(weave.Model):
|
|
40 |
self._retriever.index(corpus_tokens)
|
41 |
if index_name:
|
42 |
self._retriever.save(
|
43 |
-
index_name, corpus=[dict(row) for row in
|
44 |
)
|
45 |
if wandb.run:
|
46 |
artifact = wandb.Artifact(
|
@@ -81,8 +81,8 @@ class BM25sRetriever(weave.Model):
|
|
81 |
stopwords=LANGUAGE_DICT[self.language],
|
82 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
83 |
)
|
84 |
-
results
|
85 |
return {
|
86 |
-
"results": results,
|
87 |
-
"scores": scores,
|
88 |
}
|
|
|
29 |
super().__init__(language=language, use_stemmer=use_stemmer)
|
30 |
self._retriever = retriever or bm25s.BM25()
|
31 |
|
32 |
+
def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
|
33 |
+
chunk_dataset = weave.ref(chunk_dataset_name).get().rows
|
34 |
+
corpus = [row["text"] for row in chunk_dataset]
|
35 |
corpus_tokens = bm25s.tokenize(
|
36 |
corpus,
|
37 |
stopwords=LANGUAGE_DICT[self.language],
|
|
|
40 |
self._retriever.index(corpus_tokens)
|
41 |
if index_name:
|
42 |
self._retriever.save(
|
43 |
+
index_name, corpus=[dict(row) for row in chunk_dataset]
|
44 |
)
|
45 |
if wandb.run:
|
46 |
artifact = wandb.Artifact(
|
|
|
81 |
stopwords=LANGUAGE_DICT[self.language],
|
82 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
83 |
)
|
84 |
+
results = self._retriever.retrieve(query_tokens, k=top_k)
|
85 |
return {
|
86 |
+
"results": results.documents,
|
87 |
+
"scores": results.scores,
|
88 |
}
|