Graduation / pipelines /bm25s /tests /core /test_save_load.py
DuyTa's picture
Upload folder using huggingface_hub
74b1bac verified
raw
history blame
3.29 kB
import os
import shutil
from pathlib import Path
import unittest
import tempfile
import bm25s
import Stemmer # optional: for stemming
class TestBM25SLoadingSaving(unittest.TestCase):
    """Check that a BM25 index written with ``save`` can be read back with ``load``.

    A small retriever is indexed once per class in ``setUpClass`` and shared by
    the tests through class attributes. Every test writes the index into the
    class-wide temporary directory itself, so each test can be run on its own
    and does not rely on the execution order of the others.
    """

    # Keyword arguments passed to both ``save`` and ``load``. The values are the
    # file names expected inside the temporary directory. Keeping them in one
    # mapping guarantees that the save call, the load call and the existence
    # check always agree.
    _FILE_NAMES = {
        "data_name": "data.index.csc.npy",
        "indices_name": "indices.index.csc.npy",
        "indptr_name": "indptr.index.csc.npy",
        "vocab_name": "vocab.json",
        "nnoc_name": "nonoccurrence_array.npy",
        "params_name": "params.json",
    }

    @classmethod
    def setUpClass(cls):
        """Index a four-document corpus and create the temporary directory."""
        corpus = [
            "a cat is a feline and likes to purr",
            "a dog is the human's best friend and loves to play",
            "a bird is a beautiful animal that can fly",
            "a fish is a creature that lives in water and swims",
        ]

        # optional: create a stemmer
        stemmer = Stemmer.Stemmer("english")

        # Tokenize the corpus with stopwords="en" and the stemmer above
        corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)

        # Create the BM25 model and index the corpus
        retriever = bm25s.BM25(method='bm25+')
        retriever.index(corpus_tokens)

        # Share the fixtures with every test in the class
        cls.retriever = retriever
        cls.corpus = corpus
        cls.corpus_tokens = corpus_tokens
        cls.stemmer = stemmer

        # Created last so that a failure above leaves no directory behind
        # (tearDownClass is not run when setUpClass raises).
        cls.tmpdirname = tempfile.mkdtemp()

    def _save_retriever(self):
        """Write the shared retriever into the temporary directory."""
        self.retriever.save(self.tmpdirname, **self._FILE_NAMES)

    def _assert_arrays_equal(self, expected, actual, label):
        """Assert that two arrays have the same shape and the same elements.

        The shape is checked first because an elementwise ``==`` is only
        meaningful when both operands have the same shape.
        """
        self.assertEqual(
            expected.shape,
            actual.shape,
            f"{label}: shape differs between the original and the loaded retriever",
        )
        self.assertTrue(
            (expected == actual).all(),
            f"{label}: values differ between the original and the loaded retriever",
        )

    def test_a_save(self):
        """``save`` must create every expected file in the target directory."""
        self._save_retriever()

        # assert that the following files are saved
        for fname in self._FILE_NAMES.values():
            with self.subTest(fname=fname):
                error_msg = (
                    f"File {fname} not found in {self.tmpdirname} even though "
                    "it should be saved by the .save() method"
                )
                path_exists = os.path.exists(os.path.join(self.tmpdirname, fname))
                self.assertTrue(path_exists, error_msg)

    def test_b_load(self):
        """A loaded retriever must match the one that was saved.

        Compared: the ``data``, ``indices`` and ``indptr`` entries of
        ``scores``, ``vocab_dict`` and ``nonoccurrence_array``. The contents
        of ``params.json`` are not compared here.
        """
        # Save here as well, so this test passes when run on its own instead
        # of depending on test_a_save having written the files first.
        self._save_retriever()

        r1 = self.retriever
        r2 = bm25s.BM25.load(self.tmpdirname, **self._FILE_NAMES)

        # data, indices, indptr are stored in self.scores
        for key in ("data", "indices", "indptr"):
            self._assert_arrays_equal(r1.scores[key], r2.scores[key], f"scores[{key!r}]")

        # vocab is stored in self.vocab_dict
        self.assertEqual(r1.vocab_dict, r2.vocab_dict)

        # nnoc is stored in self.nonoccurrence_array
        self._assert_arrays_equal(
            r1.nonoccurrence_array, r2.nonoccurrence_array, "nonoccurrence_array"
        )

    @classmethod
    def tearDownClass(cls):
        """Delete the temporary directory and everything saved into it."""
        shutil.rmtree(cls.tmpdirname)