|
import os |
|
import shutil |
|
from pathlib import Path |
|
import unittest |
|
import tempfile |
|
|
|
import numpy as np |
|
import bm25s |
|
import Stemmer |
|
|
|
class TestBM25SLoadingSaving(unittest.TestCase):
    """Check which query formats ``BM25.retrieve`` accepts and which it rejects.

    A four-document corpus is tokenized and indexed once for the whole class
    (``bm25+`` method); every test then queries that shared index.
    """

    # Single source of truth for the query text, so every test tokenizes
    # exactly the same string.
    QUERY = "a cat is a feline, it's sometimes beautiful but cannot fly"

    @classmethod
    def setUpClass(cls):
        """Tokenize the corpus and build the index shared by all tests."""
        corpus = [
            "a cat is a feline and likes to purr",
            "a dog is the human's best friend and loves to play",
            "a bird is a beautiful animal that can fly",
            "a fish is a creature that lives in water and swims",
        ]

        stemmer = Stemmer.Stemmer("english")
        corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)

        retriever = bm25s.BM25(method="bm25+")
        retriever.index(corpus_tokens)

        # Expose the fixtures on the class so individual tests can reuse them.
        cls.retriever = retriever
        cls.corpus = corpus
        cls.corpus_tokens = corpus_tokens
        cls.stemmer = stemmer

    def _tokenize_query(self, return_ids):
        """Tokenize ``QUERY`` with the same stopwords/stemmer as the corpus.

        Args:
            return_ids: Forwarded to ``bm25s.tokenize``; the tests use ``True``
                for the object that unpacks into ``(ids, vocab)`` and ``False``
                for the per-query token sequences.
        """
        return bm25s.tokenize(
            [self.QUERY],
            stopwords="en",
            stemmer=self.stemmer,
            return_ids=return_ids,
        )

    def _assert_documents(self, expected, query_tokens, k):
        """Retrieve the top ``k`` documents and compare them with ``expected``."""
        results = self.retriever.retrieve(query_tokens, k=k).documents
        self.assertTrue(
            np.array_equal(expected, results),
            f"Expected {expected}, got {results}",
        )

    def test_retrieve(self):
        """Every supported query format must return the same documents."""
        ground_truth = np.array([[0, 2]])

        # 1) The object returned by the tokenizer when ids are requested.
        query_tokens_obj = self._tokenize_query(return_ids=True)
        self._assert_documents(ground_truth, query_tokens_obj, k=2)

        # 2) The tokenizer output when ids are not requested.
        query_tokens_texts = self._tokenize_query(return_ids=False)
        self._assert_documents(ground_truth, query_tokens_texts, k=2)

        # 3) A plain (ids, vocab) tuple rebuilt from the tokenizer object.
        ids, vocab = query_tokens_obj
        self._assert_documents(ground_truth, (ids, vocab), k=2)

        # 4) A tuple holding the same tokenized query twice: it is expected to
        #    be treated as a batch of two queries, hence two result rows.
        queries_as_tuple = (query_tokens_texts[0], query_tokens_texts[0])
        self._assert_documents(np.array([[0], [0]]), queries_as_tuple, k=1)

    def test_failure_of_bad_tuple(self):
        """Malformed tuples built from ids/vocab must raise ``ValueError``."""
        ids, vocab = self._tokenize_query(return_ids=True)

        malformed_queries = [
            ("ids and vocab swapped", (vocab, ids)),
            ("ids passed twice, no vocab", (ids, ids)),
            ("one-element tuple holding only vocab", (vocab,)),
        ]
        for label, bad_tuple in malformed_queries:
            # The label is attached to the failure message so it is clear
            # which malformed input was accepted by mistake.
            with self.assertRaises(ValueError, msg=label):
                self.retriever.retrieve(bad_tuple, k=2)

    @classmethod
    def tearDownClass(cls):
        """Nothing to clean up: the fixtures built in ``setUpClass`` are in-memory only."""
        pass