# artech-med-bot / test_rag.py
# Author: shamim237
# Commit: 8ff45d7 "initial commit" (verified)
import unittest
from rag import load_faiss_index, generate_answer
from langchain.retrievers import MergerRetriever
class TestRetrieval(unittest.TestCase):
    """Integration tests for the FAISS retrievers and ``generate_answer`` from ``rag``.

    ``setUpClass`` loads two FAISS indices from the paths below via
    ``rag.load_faiss_index``, so these tests need those index directories to be
    present relative to the working directory.
    """

    # Both indices are loaded with the same embedding model; keep it in one place
    # so the two loads cannot drift apart.
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L12-v2"
    DATA_INDEX_PATH = "./vectors_data/msd_data_vec"
    MED_INDEX_PATH = "./vectors_data/med_data_vec"

    @classmethod
    def setUpClass(cls):
        """Load both FAISS indices once and build the individual and merged retrievers."""
        cls.data_vec = load_faiss_index(cls.DATA_INDEX_PATH, cls.EMBEDDING_MODEL)
        cls.med_vec = load_faiss_index(cls.MED_INDEX_PATH, cls.EMBEDDING_MODEL)
        cls.data_retriever = cls.data_vec.as_retriever()
        cls.med_retriever = cls.med_vec.as_retriever()
        # Combine both retrievers into a single retriever.
        cls.combined_retriever = MergerRetriever(
            retrievers=[cls.data_retriever, cls.med_retriever]
        )

    @staticmethod
    def _retrieve(retriever, query):
        """Run *query* against *retriever* and return the retrieved documents.

        NOTE(review): ``get_relevant_documents`` is deprecated in newer LangChain
        releases in favour of ``retriever.invoke(query)``; switch here once the
        project's pinned LangChain version is confirmed to support it.
        """
        return retriever.get_relevant_documents(query)

    def _assert_valid_docs(self, docs):
        """Assert *docs* is non-empty and every document has non-blank ``page_content``."""
        self.assertIsNotNone(docs)
        self.assertGreater(len(docs), 0, "retriever returned no documents")
        # One subTest per document so a failure names the offending document.
        for index, doc in enumerate(docs):
            with self.subTest(doc_index=index):
                self.assertNotEqual(
                    doc.page_content.strip(),
                    "",
                    f"document {index} has blank page_content",
                )

    def test_data_retriever(self):
        """Assert the retriever over ``msd_data_vec`` returns usable documents."""
        docs = self._retrieve(self.data_retriever, "what are the symptoms of diabetes?")
        self._assert_valid_docs(docs)

    def test_med_retriever(self):
        """Assert the retriever over ``med_data_vec`` returns usable documents."""
        docs = self._retrieve(self.med_retriever, "what are common antibiotics?")
        self._assert_valid_docs(docs)

    def test_combined_retriever(self):
        """Assert the merged retriever returns usable documents."""
        docs = self._retrieve(
            self.combined_retriever, "what is the treatment for high blood pressure?"
        )
        self._assert_valid_docs(docs)

    def test_generate_answer(self):
        """Assert ``generate_answer`` returns a non-empty string for a normal query."""
        response = generate_answer("what are the side effects of aspirin?")
        # assertIsInstance(..., str) already rules out None.
        self.assertIsInstance(response, str)
        self.assertGreater(len(response), 0, "generate_answer returned an empty string")

    def test_empty_query(self):
        """Assert ``generate_answer`` raises ValueError for empty and whitespace-only queries."""
        # subTest makes both inputs run and report independently.
        for query in ("", " "):
            with self.subTest(query=query):
                with self.assertRaises(ValueError):
                    generate_answer(query)
if __name__ == '__main__':
    # Lets the module run directly (e.g. `python test_rag.py`) as well as via test discovery.
    unittest.main()