advisor / test_embeddings.py
veerukhannan's picture
Update test_embeddings.py
e55cfd3 verified
import os
import chromadb
from sentence_transformers import SentenceTransformer
from loguru import logger
class SentenceTransformerEmbeddings:
    """Callable wrapper that embeds texts with a sentence-transformers model.

    Instances are handed to ChromaDB as the ``embedding_function`` of a
    collection elsewhere in this file.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the model identified by *model_name* and keep it on ``self.model``."""
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Encode every text in *input* and return the vectors as plain lists.

        The parameter is deliberately left named ``input`` (shadowing the
        builtin) so keyword callers keep working.
        """
        # encode() yields an array-like; tolist() converts it to nested lists.
        return self.model.encode(input).tolist()
def _split_into_sections(document, index_content):
    """Split *document* into titled sections using *index_content* as headings.

    A line counts as a heading when any entry of *index_content* occurs in it
    as a substring. Non-empty lines that follow a heading form that section's
    content. Text before the first heading, and headings that are followed by
    no text, produce no section.

    Args:
        document: Full document text.
        index_content: Non-empty heading marker strings.

    Returns:
        A list of ``{"title": str, "content": str}`` dicts in document order;
        every ``content`` is non-empty.
    """
    sections = []
    title = ""
    body_lines = []
    for raw_line in document.split('\n'):
        line = raw_line.strip()
        if any(marker in line for marker in index_content):
            # A new heading closes the section collected so far.
            if title and body_lines:
                sections.append({"title": title, "content": "\n".join(body_lines)})
            title = line
            body_lines = []
        elif line:
            body_lines.append(line)
    # Close the final section, which has no heading after it.
    if title and body_lines:
        sections.append({"title": title, "content": "\n".join(body_lines)})
    return sections


def initialize_chromadb():
    """Initialize ChromaDB and load documents if needed.

    Looks for ``a2023-45.txt`` and ``index.txt`` next to this file and keeps a
    persistent ChromaDB store in a sibling ``chroma_db`` directory. If the
    ``legal_documents`` collection already holds documents nothing is changed;
    otherwise the collection is (re)created and filled with one entry per
    document section.

    Returns:
        bool: True when the collection already had content or was populated
        successfully. False when an input file is missing, no sections could
        be extracted, or any exception was raised (the error is logged).
    """
    collection_name = "legal_documents"
    source_name = 'a2023-45.txt'
    try:
        # Set up paths
        base_path = os.path.dirname(os.path.abspath(__file__))
        doc_path = os.path.join(base_path, source_name)
        index_path = os.path.join(base_path, 'index.txt')
        chroma_path = os.path.join(base_path, 'chroma_db')

        # Check if required files exist
        if not os.path.exists(doc_path):
            logger.error(f"Document file not found at {doc_path}")
            return False
        if not os.path.exists(index_path):
            logger.error(f"Index file not found at {index_path}")
            return False

        # Ensure ChromaDB directory exists
        os.makedirs(chroma_path, exist_ok=True)

        # Initialize ChromaDB
        chroma_client = chromadb.PersistentClient(path=chroma_path)
        embedding_function = SentenceTransformerEmbeddings()

        # Depending on the ChromaDB release, list_collections() yields either
        # collection objects (with a ``name`` attribute) or plain name
        # strings; accept both instead of failing on ``str.name``.
        existing_names = {
            getattr(col, "name", col) for col in chroma_client.list_collections()
        }
        collection_exists = collection_name in existing_names

        if collection_exists:
            collection = chroma_client.get_collection(
                name=collection_name,
                embedding_function=embedding_function
            )
            if collection.count() > 0:
                logger.info("ChromaDB collection already exists and has content")
                return True

        # If we get here, we need to create or repopulate the collection
        logger.info("Loading documents into ChromaDB...")

        # Drop the existing (empty) collection so it is rebuilt from scratch
        if collection_exists:
            chroma_client.delete_collection(collection_name)

        collection = chroma_client.create_collection(
            name=collection_name,
            embedding_function=embedding_function
        )

        # Read and process documents
        with open(doc_path, 'r', encoding='utf-8') as f:
            document = f.read().strip()
        with open(index_path, 'r', encoding='utf-8') as f:
            index_content = [line.strip() for line in f if line.strip()]

        sections = _split_into_sections(document, index_content)
        if not sections:
            logger.error("No valid sections found in document")
            return False

        # Every section has non-empty content (guaranteed by the splitter),
        # so documents, metadatas and ids line up one-to-one.
        documents = [section["content"] for section in sections]
        metadatas = [
            {
                "title": section["title"],
                "source": source_name,
                "section_number": number,
            }
            for number, section in enumerate(sections, start=1)
        ]
        ids = [f"section_{number}" for number in range(1, len(sections) + 1)]

        collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )
        logger.info(f"Successfully loaded {len(documents)} sections into ChromaDB")
        return True
    except Exception as e:
        # Top-level boundary: report the failure to the caller as False.
        logger.error(f"Error initializing ChromaDB: {e}")
        return False
def test_chromadb_content():
    """Test if ChromaDB has the required content.

    Makes sure the store is initialized, then checks that the
    ``legal_documents`` collection is non-empty and that a sample query
    actually returns at least one document.

    Returns:
        bool: True when every check passes. False when initialization fails,
        the collection is empty, the sample query yields nothing, or any
        exception was raised (the error is logged).
    """
    try:
        # First ensure ChromaDB is initialized
        if not initialize_chromadb():
            return False

        # Set up ChromaDB path
        base_path = os.path.dirname(os.path.abspath(__file__))
        chroma_path = os.path.join(base_path, 'chroma_db')

        # Initialize ChromaDB and fetch the collection
        chroma_client = chromadb.PersistentClient(path=chroma_path)
        collection = chroma_client.get_collection(
            name="legal_documents",
            embedding_function=SentenceTransformerEmbeddings()
        )

        # Check collection size
        count = collection.count()
        if count == 0:
            logger.error("Collection is empty")
            return False
        logger.info(f"Found {count} documents in ChromaDB")

        # Test query to verify content
        test_results = collection.query(
            query_texts=["What are the general provisions?"],
            n_results=1
        )
        # query() returns one inner list per query text, so "no hits" looks
        # like [[]] -- which is truthy. Inspect the inner list for our single
        # query as well, otherwise this check can never fail.
        documents = test_results['documents'] or []
        if not documents or not documents[0]:
            logger.error("Test query returned no results")
            return False

        logger.info("ChromaDB content verification successful")
        return True
    except Exception as e:
        # Top-level boundary: report the failure to the caller as False.
        logger.error(f"Error testing ChromaDB: {e}")
        return False
if __name__ == "__main__":
    # Script entry point: run the end-to-end check and report the verdict.
    success = test_chromadb_content()
    print(
        "ChromaDB content verification successful"
        if success
        else "ChromaDB content verification failed"
    )