Spaces:
Sleeping
Sleeping
import unittest | |
from rag import load_faiss_index, generate_answer | |
from langchain.retrievers import MergerRetriever | |
class TestRetrieval(unittest.TestCase):
    """Integration tests for the FAISS retrievers and ``generate_answer`` from ``rag``.

    The indices are loaded from ``./vectors_data`` once per class, so these
    tests require those index directories to exist on disk.
    """

    # Embedding model used to load both indices; kept in one place so the two
    # loads cannot silently drift apart.
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L12-v2"
    DATA_INDEX_PATH = "./vectors_data/msd_data_vec"
    MED_INDEX_PATH = "./vectors_data/med_data_vec"

    @classmethod
    def setUpClass(cls):
        """Load both FAISS indices once and build the individual and merged retrievers.

        ``unittest`` invokes ``setUpClass`` on the class itself, so it must be a
        classmethod; without the decorator the runner raises a TypeError for
        the missing ``cls`` argument before any test runs.
        """
        super().setUpClass()
        # Initialize FAISS indices
        cls.data_vec = load_faiss_index(cls.DATA_INDEX_PATH, cls.EMBEDDING_MODEL)
        cls.med_vec = load_faiss_index(cls.MED_INDEX_PATH, cls.EMBEDDING_MODEL)
        # Initialize retrievers
        cls.data_retriever = cls.data_vec.as_retriever()
        cls.med_retriever = cls.med_vec.as_retriever()
        # Combine both retrievers into a single retriever
        cls.combined_retriever = MergerRetriever(
            retrievers=[cls.data_retriever, cls.med_retriever]
        )

    def _assert_valid_docs(self, docs):
        """Assert that *docs* is a non-empty sequence whose items have non-blank ``page_content``."""
        self.assertIsNotNone(docs)
        self.assertGreater(len(docs), 0, "retriever returned no documents")
        for index, doc in enumerate(docs):
            self.assertNotEqual(
                doc.page_content.strip(),
                "",
                f"document {index} has empty page_content",
            )

    # NOTE(review): get_relevant_documents is deprecated in newer LangChain
    # releases in favour of retriever.invoke(query); switch once the pinned
    # LangChain version is confirmed to support it.

    def test_data_retriever(self):
        """The data retriever returns non-empty documents for a diabetes query."""
        docs = self.data_retriever.get_relevant_documents(
            "what are the symptoms of diabetes?"
        )
        self._assert_valid_docs(docs)

    def test_med_retriever(self):
        """The medical retriever returns non-empty documents for an antibiotics query."""
        docs = self.med_retriever.get_relevant_documents(
            "what are common antibiotics?"
        )
        self._assert_valid_docs(docs)

    def test_combined_retriever(self):
        """The merged retriever returns non-empty documents for a blood-pressure query."""
        docs = self.combined_retriever.get_relevant_documents(
            "what is the treatment for high blood pressure?"
        )
        self._assert_valid_docs(docs)

    def test_generate_answer(self):
        """``generate_answer`` returns a non-blank string for a normal query."""
        response = generate_answer("what are the side effects of aspirin?")
        # assertIsInstance also rejects None, so no separate None check is needed.
        self.assertIsInstance(response, str)
        self.assertNotEqual(response.strip(), "", "generate_answer returned a blank answer")

    def test_empty_query(self):
        """``generate_answer`` rejects empty and whitespace-only queries with ValueError."""
        for query in ("", " "):
            with self.subTest(query=query):
                with self.assertRaises(ValueError):
                    generate_answer(query)
if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line
    # runner, which collects and runs the TestCase classes in this module.
    unittest.main()