Spaces:
Sleeping
Sleeping
File size: 3,004 Bytes
8ff45d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import unittest
from rag import load_faiss_index, generate_answer
from langchain.retrievers import MergerRetriever
class TestRetrieval(unittest.TestCase):
    """Integration tests for the FAISS retrievers and answer generation in ``rag``.

    The FAISS indices are loaded from the relative paths ``./vectors_data/...``,
    so (presumably) the suite must be run from the directory that contains
    ``vectors_data`` — TODO confirm how ``load_faiss_index`` resolves paths.
    """

    # Embedding model used to load BOTH indices; a single constant keeps the
    # two loads from silently diverging.
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L12-v2"

    @classmethod
    def setUpClass(cls):
        """Load both FAISS indices once and build the individual and merged retrievers."""
        # Initialize FAISS indices
        cls.data_vec = load_faiss_index("./vectors_data/msd_data_vec", cls.EMBEDDING_MODEL)
        cls.med_vec = load_faiss_index("./vectors_data/med_data_vec", cls.EMBEDDING_MODEL)
        # Initialize retrievers
        cls.data_retriever = cls.data_vec.as_retriever()
        cls.med_retriever = cls.med_vec.as_retriever()
        # Combine both retrievers into a single retriever
        cls.combined_retriever = MergerRetriever(
            retrievers=[cls.data_retriever, cls.med_retriever]
        )

    def _assert_valid_docs(self, docs):
        """Assert that *docs* is a non-empty collection of documents with non-blank content.

        Shared by every retriever test so the acceptance criteria stay identical.
        """
        self.assertIsNotNone(docs)
        self.assertGreater(len(docs), 0, "retriever returned no documents")
        for index, doc in enumerate(docs):
            # subTest reports exactly which document is blank instead of a bare
            # "False is not true" from assertTrue(all(...)).
            with self.subTest(doc_index=index):
                self.assertNotEqual(
                    doc.page_content.strip(), "", "document has blank page_content"
                )

    def test_data_retriever(self):
        """The data retriever returns non-empty documents for a diabetes query."""
        query = "what are the symptoms of diabetes?"
        # NOTE(review): newer langchain-core releases deprecate
        # get_relevant_documents in favour of retriever.invoke(query); switch
        # once the pinned langchain version is confirmed to support it.
        docs = self.data_retriever.get_relevant_documents(query)
        self._assert_valid_docs(docs)

    def test_med_retriever(self):
        """The medical retriever returns non-empty documents for an antibiotics query."""
        query = "what are common antibiotics?"
        docs = self.med_retriever.get_relevant_documents(query)
        self._assert_valid_docs(docs)

    def test_combined_retriever(self):
        """The merged retriever returns non-empty documents for a blood-pressure query."""
        query = "what is the treatment for high blood pressure?"
        docs = self.combined_retriever.get_relevant_documents(query)
        self._assert_valid_docs(docs)

    def test_generate_answer(self):
        """generate_answer returns a non-blank string for a normal query."""
        query = "what are the side effects of aspirin?"
        response = generate_answer(query)
        self.assertIsNotNone(response)
        self.assertIsInstance(response, str)
        # Reject whitespace-only answers, consistent with the document checks above.
        self.assertNotEqual(response.strip(), "", "generate_answer returned a blank response")

    def test_empty_query(self):
        """generate_answer rejects empty and whitespace-only queries with ValueError."""
        for blank_query in ("", " "):
            # subTest reports each rejected input separately.
            with self.subTest(query=blank_query):
                with self.assertRaises(ValueError):
                    generate_answer(blank_query)
if __name__ == '__main__':
    # Running the file directly (`python <this file>`) executes the whole suite;
    # importing it (e.g. under pytest) does nothing here.
    unittest.main()
|