# advisor / add_embeddings.py
# Committed by veerukhannan: "Create add_embeddings.py" (commit 6702977, verified)
import os
import chromadb
from sentence_transformers import SentenceTransformer
from loguru import logger
class SentenceTransformerEmbeddings:
    """Callable adapter that turns a sentence-transformers model into a ChromaDB embedding function.

    Calling an instance with a list of texts returns one embedding per text,
    each embedding being a plain Python list of floats.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        # The model is loaded once here and reused by every call.
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: list[str]) -> list[list[float]]:
        # NOTE: the parameter is deliberately named ``input`` (shadowing the
        # builtin) to match ChromaDB's EmbeddingFunction call signature;
        # presumably chromadb inspects it, so do not rename without checking.
        # ``encode`` returns an array-like; ``tolist`` converts it to nested lists.
        return self.model.encode(input).tolist()
def _split_into_sections(document: str, index_titles: list[str]) -> list[dict[str, str]]:
    """Split ``document`` into titled sections, using ``index_titles`` as headings.

    Each line of ``document`` is stripped. A line containing any entry of
    ``index_titles`` as a substring starts a new section and becomes its title.
    The non-empty lines that follow form that section's content, joined with
    newlines. Text before the first heading is discarded, and so are headings
    with no content after them.

    Returns:
        A list of ``{"title": ..., "content": ...}`` dicts in document order.
    """
    sections: list[dict[str, str]] = []
    title = ""
    body: list[str] = []

    for raw_line in document.split("\n"):
        line = raw_line.strip()
        if any(entry in line for entry in index_titles):
            # A heading closes the section in progress (if it has content).
            if title and body:
                sections.append({"title": title, "content": "\n".join(body)})
            title = line
            body = []
        elif line:
            body.append(line)

    # Flush the final section.
    if title and body:
        sections.append({"title": title, "content": "\n".join(body)})
    return sections


def load_documents():
    """Rebuild the ``legal_documents`` ChromaDB collection from local text files.

    Reads ``a2023-45.txt`` (the document) and ``index.txt`` (one section heading
    per non-blank line) from this script's directory. It splits the document into
    sections at heading lines and stores each section's text, with its title as
    metadata, in a persistent ChromaDB database under ``chroma_db/``. Any existing
    ``legal_documents`` collection is replaced, but only after the inputs have
    been read and parsed successfully.

    Returns:
        True on success. False if any step failed (the error is logged with its
        traceback) or if no sections were found.
    """
    collection_name = "legal_documents"
    try:
        # Set up paths relative to this script.
        base_path = os.path.dirname(os.path.abspath(__file__))
        doc_path = os.path.join(base_path, 'a2023-45.txt')
        index_path = os.path.join(base_path, 'index.txt')
        chroma_path = os.path.join(base_path, 'chroma_db')

        logger.info(f"Loading documents from {doc_path} and {index_path}")

        # Read and parse the inputs BEFORE touching the database. A missing or
        # unreadable file then leaves the existing collection intact instead of
        # deleting it and failing halfway through the rebuild.
        with open(doc_path, 'r', encoding='utf-8') as f:
            document = f.read().strip()
        with open(index_path, 'r', encoding='utf-8') as f:
            index_titles = [line.strip() for line in f if line.strip()]

        sections = _split_into_sections(document, index_titles)
        if not sections:
            logger.error(f"No sections found in {doc_path}; leaving ChromaDB unchanged")
            return False

        # Prepare data for ChromaDB. Every section has non-empty content here.
        documents = [section["content"] for section in sections]
        metadatas = [
            {
                "title": section["title"],
                "source": "a2023-45.txt",
                "section_number": number,
            }
            for number, section in enumerate(sections, start=1)
        ]
        ids = [f"section_{number}" for number in range(1, len(sections) + 1)]

        # Initialize ChromaDB only once the data is ready.
        os.makedirs(chroma_path, exist_ok=True)
        chroma_client = chromadb.PersistentClient(path=chroma_path)
        embedding_function = SentenceTransformerEmbeddings()

        # Replace any previous collection. Depending on the chromadb version,
        # list_collections() yields Collection objects or plain name strings,
        # so accept both shapes.
        existing = {getattr(col, "name", col) for col in chroma_client.list_collections()}
        if collection_name in existing:
            chroma_client.delete_collection(collection_name)
        collection = chroma_client.create_collection(
            name=collection_name,
            embedding_function=embedding_function,
        )

        collection.add(documents=documents, metadatas=metadatas, ids=ids)

        logger.info(f"Successfully loaded {len(documents)} sections into ChromaDB")
        return True
    except Exception as e:
        # Top-level boundary of this script: log with the full traceback and
        # report failure to the caller instead of raising.
        logger.exception(f"Error loading documents: {str(e)}")
        return False
if __name__ == "__main__":
    import sys

    success = load_documents()
    if success:
        print("Documents successfully loaded into ChromaDB")
    else:
        print("Failed to load documents into ChromaDB")
        # Exit non-zero so shell scripts / CI pipelines can detect the failure.
        sys.exit(1)