|
import os |
|
import fitz |
|
import re |
|
import chromadb |
|
from chromadb.utils import embedding_functions |
|
import numpy as np |
|
import torch |
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
|
class VectorDatabase:
    """Vector database for storing and retrieving tenant rights information from PDF.

    The PDF is split into per-state sections and stored in two ChromaDB
    collections:

    - ``tenant_documents``: overlapping text chunks for fine-grained retrieval.
    - ``tenant_states``: one short summary document per state.
    """

    def __init__(self, persist_directory="./data/chroma_db"):
        """Initialize the vector database.

        Args:
            persist_directory: Directory where ChromaDB persists its data.
                Created if it does not already exist.

        Raises:
            Exception: Any failure initializing the embedding function,
                client, or collections is logged and re-raised.
        """
        logging.info("Initializing VectorDatabase")
        # Version info is logged because embedding backends are sensitive to
        # numpy/torch version mismatches.
        logging.info("NumPy version: %s", np.__version__)
        logging.info("PyTorch version: %s", torch.__version__)

        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)

        try:
            logging.info("Creating embedding function")
            self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="all-MiniLM-L6-v2"
            )

            logging.info("Initializing ChromaDB client")
            self.client = chromadb.PersistentClient(path=persist_directory)

            logging.info("Setting up collections")
            self.document_collection = self._get_or_create_collection("tenant_documents")
            self.state_collection = self._get_or_create_collection("tenant_states")
        except Exception as e:
            logging.error("Initialization failed: %s", e)
            raise

    def _get_or_create_collection(self, name):
        """Get or create a collection with the given name.

        NOTE: ChromaDB versions differ in the exception type raised for a
        missing collection, so the fallback deliberately catches broadly.
        """
        try:
            return self.client.get_collection(
                name=name,
                embedding_function=self.embedding_function
            )
        except Exception:
            return self.client.create_collection(
                name=name,
                embedding_function=self.embedding_function
            )

    def extract_pdf_content(self, pdf_path):
        """Extract content from PDF file and identify state sections.

        Args:
            pdf_path: Path to the PDF to parse.

        Returns:
            Mapping of state name -> section text. Falls back to a single
            ``{"Full Document": ...}`` entry when no state headings match.

        Raises:
            FileNotFoundError: If ``pdf_path`` does not exist.
        """
        logging.info("Extracting content from PDF: %s", pdf_path)

        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")

        doc = fitz.open(pdf_path)
        try:
            # Join at C speed instead of quadratic string concatenation.
            full_text = "\n".join(
                doc.load_page(page_num).get_text("text")
                for page_num in range(len(doc))
            )
        finally:
            # Close the document even if text extraction raises.
            doc.close()

        # Headings such as "New York Landlord-Tenant Law" mark state sections.
        state_pattern = r"(?m)^\s*([A-Z][a-z]+(?:\s[A-Z][a-z]+)*)\s+Landlord(?:-|\s)Tenant\s+(?:Law|Laws)"
        state_matches = list(re.finditer(state_pattern, full_text))

        if not state_matches:
            logging.info("No state sections found. Treating as single document.")
            return {"Full Document": full_text.strip()}

        state_sections = {}
        for i, match in enumerate(state_matches):
            state_name = match.group(1).strip()
            # A section runs from the end of its heading to the start of the
            # next heading (or the end of the document for the last state).
            start_pos = match.end()
            end_pos = state_matches[i + 1].start() if i + 1 < len(state_matches) else len(full_text)
            state_text = full_text[start_pos:end_pos].strip()
            if state_text:
                state_sections[state_name] = state_text

        logging.info("Extracted content for %d states", len(state_sections))
        return state_sections

    def process_and_load_pdf(self, pdf_path):
        """Process PDF and load content into vector database.

        Previously stored chunks and summaries are deleted first, so the
        collections always reflect the latest PDF contents.

        Args:
            pdf_path: Path to the landlord-tenant PDF.

        Returns:
            Number of state sections loaded.
        """
        state_sections = self.extract_pdf_content(pdf_path)

        # Clear out stale entries before re-loading. (Renamed from the
        # original, which shadowed these with the new-batch lists below.)
        existing_doc_ids = self.document_collection.get()["ids"]
        existing_state_ids = self.state_collection.get()["ids"]

        if existing_doc_ids:
            self.document_collection.delete(ids=existing_doc_ids)
        if existing_state_ids:
            self.state_collection.delete(ids=existing_state_ids)

        document_ids, document_texts, document_metadatas = [], [], []
        state_ids, state_texts, state_metadatas = [], [], []

        source_name = os.path.basename(pdf_path)  # loop invariant, hoisted
        for state, text in state_sections.items():
            state_key = state.lower().replace(" ", "_")
            # Summary entry: the first ~1000 characters of the section.
            summary = text[:1000].strip() if len(text) > 1000 else text
            state_ids.append(f"state_{state_key}")
            state_texts.append(summary)
            state_metadatas.append({"state": state, "type": "summary"})

            chunks = self._chunk_text(text, chunk_size=1000, overlap=200)
            for i, chunk in enumerate(chunks):
                document_ids.append(f"doc_{state_key}_{i}")
                document_texts.append(chunk)
                document_metadatas.append({
                    "state": state,
                    "chunk_id": i,
                    "total_chunks": len(chunks),
                    "source": source_name
                })

        if document_ids:
            self.document_collection.add(
                ids=document_ids,
                documents=document_texts,
                metadatas=document_metadatas
            )
        if state_ids:
            self.state_collection.add(
                ids=state_ids,
                documents=state_texts,
                metadatas=state_metadatas
            )

        logging.info("Loaded %d document chunks and %d state summaries",
                     len(document_ids), len(state_ids))
        return len(state_sections)

    def _chunk_text(self, text, chunk_size=1000, overlap=200):
        """Split text into overlapping chunks.

        Chunks are at most ``chunk_size`` characters and, where possible, end
        at the last sentence/line boundary within range. Consecutive chunks
        overlap by roughly ``overlap`` characters so context is preserved at
        chunk edges.

        Args:
            text: Text to split; empty input yields ``[]``.
            chunk_size: Maximum characters per chunk.
            overlap: Characters of overlap between consecutive chunks.

        Returns:
            List of non-empty, stripped chunk strings.
        """
        if not text:
            return []

        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = min(start + chunk_size, text_length)
            if end < text_length:
                # Prefer to break just after a period or newline when one
                # exists past the chunk start.
                last_period = text.rfind(".", start, end)
                last_newline = text.rfind("\n", start, end)
                split_point = max(last_period, last_newline)
                if split_point > start:
                    end = split_point + 1
            chunk = text[start:end].strip()
            if chunk:  # skip whitespace-only slices
                chunks.append(chunk)
            if end >= text_length:
                # BUG FIX: without this break, the loop restarted at
                # end - overlap after the final chunk and re-emitted the
                # tail of the text as a duplicate chunk.
                break
            start = end - overlap if end - overlap > start else end

        return chunks

    def query(self, query_text, state=None, n_results=5):
        """Query the vector database for relevant tenant rights information.

        Args:
            query_text: Free-text query to embed and search with.
            state: Optional state name; restricts both searches to that state.
            n_results: Maximum number of results per collection.

        Returns:
            Dict with ``document_results`` (chunk matches) and
            ``state_results`` (summary matches), each in ChromaDB's native
            query-result format.
        """
        state_filter = {"state": state} if state else None

        document_results = self.document_collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=state_filter
        )
        state_results = self.state_collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=state_filter
        )

        return {"document_results": document_results, "state_results": state_results}

    def get_states(self):
        """Get a sorted list of all states in the database."""
        results = self.state_collection.get()
        # ``metadatas`` can be None when excluded from the result; treat as empty.
        states = {meta["state"] for meta in (results["metadatas"] or []) if meta}
        return sorted(states)
|
|
|
if __name__ == "__main__":
    # Demo entry point: build the database from the bundled PDF and report
    # which states were loaded. Failures are logged, then re-raised so the
    # process exits non-zero.
    try:
        database = VectorDatabase()
        database.process_and_load_pdf("data/tenant-landlord.pdf")
        states = database.get_states()
        print(f"Available states: {states}")
    except Exception as e:
        logging.error(f"Script execution failed: {str(e)}")
        raise