import os
import fitz # PyMuPDF
import re
import chromadb
from chromadb.utils import embedding_functions
import numpy as np
import torch
import logging
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class VectorDatabase:
    """Vector database for storing and retrieving tenant rights information from PDF."""

    def __init__(self, persist_directory="./chroma_db"):
        """Initialize the vector database."""
        logging.info("Initializing VectorDatabase")
        logging.info(f"NumPy version: {np.__version__}")
        logging.info(f"PyTorch version: {torch.__version__}")
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)
        try:
            logging.info("Creating embedding function")
            self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="all-MiniLM-L6-v2"
            )
            logging.info("Initializing ChromaDB client")
            self.client = chromadb.PersistentClient(path=persist_directory)
            logging.info("Setting up collections")
            self.document_collection = self._get_or_create_collection("tenant_documents")
            self.state_collection = self._get_or_create_collection("tenant_states")
        except Exception as e:
            logging.error(f"Initialization failed: {str(e)}")
            raise

    def _get_or_create_collection(self, name):
        """Get or create a collection with the given name."""
        try:
            return self.client.get_collection(
                name=name,
                embedding_function=self.embedding_function
            )
        except Exception:
            return self.client.create_collection(
                name=name,
                embedding_function=self.embedding_function
            )
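
    # Note: chromadb clients typically also expose get_or_create_collection(),
    # which could replace the try/except above; the fallback form is kept here
    # in case the installed version behaves differently (an assumption about
    # the target environment, not something this project requires).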

    def extract_pdf_content(self, pdf_path):
        """Extract content from PDF file and identify state sections."""
        logging.info(f"Extracting content from PDF: {pdf_path}")
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
        doc = fitz.open(pdf_path)
        full_text = ""
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            full_text += page.get_text("text") + "\n"
        doc.close()
        # Matches section headings such as "California Landlord-Tenant Law" or
        # "New Mexico Landlord Tenant Laws" at the start of a line.
        state_pattern = r"(?m)^\s*([A-Z][a-z]+(?:\s[A-Z][a-z]+)*)\s+Landlord(?:-|\s)Tenant\s+(?:Law|Laws)"
        state_matches = list(re.finditer(state_pattern, full_text))
        if not state_matches:
            logging.info("No state sections found. Treating as single document.")
            return {"Full Document": full_text.strip()}
        state_sections = {}
        for i, match in enumerate(state_matches):
            state_name = match.group(1).strip()
            start_pos = match.end()
            end_pos = state_matches[i + 1].start() if i + 1 < len(state_matches) else len(full_text)
            state_text = full_text[start_pos:end_pos].strip()
            if state_text:
                state_sections[state_name] = state_text
        logging.info(f"Extracted content for {len(state_sections)} states")
        return state_sections
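
    # Rough sketch of the expected output (illustrative, not taken from a real PDF):
    #   extract_pdf_content("tenant-landlord.pdf")
    #   -> {"California": "...California section text...",
    #       "Texas": "...Texas section text..."}
    # The heading line itself is excluded, since each section starts at match.end().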

    def process_and_load_pdf(self, pdf_path):
        """Process PDF and load content into vector database."""
        state_sections = self.extract_pdf_content(pdf_path)
        # Remove any previously stored entries before re-loading.
        existing_doc_ids = self.document_collection.get()["ids"]
        existing_state_ids = self.state_collection.get()["ids"]
        if existing_doc_ids:
            self.document_collection.delete(ids=existing_doc_ids)
        if existing_state_ids:
            self.state_collection.delete(ids=existing_state_ids)
        document_ids, document_texts, document_metadatas = [], [], []
        state_ids, state_texts, state_metadatas = [], [], []
        for state, text in state_sections.items():
            state_id = f"state_{state.lower().replace(' ', '_')}"
            summary = text[:1000].strip() if len(text) > 1000 else text
            state_ids.append(state_id)
            state_texts.append(summary)
            state_metadatas.append({"state": state, "type": "summary"})
            chunks = self._chunk_text(text, chunk_size=1000, overlap=200)
            for i, chunk in enumerate(chunks):
                doc_id = f"doc_{state.lower().replace(' ', '_')}_{i}"
                document_ids.append(doc_id)
                document_texts.append(chunk)
                document_metadatas.append({
                    "state": state,
                    "chunk_id": i,
                    "total_chunks": len(chunks),
                    "source": os.path.basename(pdf_path)
                })
        if document_ids:
            self.document_collection.add(
                ids=document_ids,
                documents=document_texts,
                metadatas=document_metadatas
            )
        if state_ids:
            self.state_collection.add(
                ids=state_ids,
                documents=state_texts,
                metadatas=state_metadatas
            )
        logging.info(f"Loaded {len(document_ids)} document chunks and {len(state_ids)} state summaries")
        return len(state_sections)
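
    # Illustrative usage (counts depend on the PDF; the filename matches the
    # one used in __main__ below):
    #   db = VectorDatabase()
    #   n_states = db.process_and_load_pdf("tenant-landlord.pdf")
    #   logging.info(f"Indexed {n_states} state sections")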

    def _chunk_text(self, text, chunk_size=1000, overlap=200):
        """Split text into overlapping chunks."""
        if not text:
            return []
        chunks = []
        start = 0
        text_length = len(text)
        while start < text_length:
            end = min(start + chunk_size, text_length)
            if end < text_length:
                # Prefer to break at the last sentence or line boundary inside the window.
                last_period = text.rfind(".", start, end)
                last_newline = text.rfind("\n", start, end)
                split_point = max(last_period, last_newline)
                if split_point > start:
                    end = split_point + 1
            chunks.append(text[start:end].strip())
            start = end - overlap if end - overlap > start else end
        return chunks
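
    # Rough behaviour sketch (values are illustrative): for a 2,500-character
    # string with no periods or newlines, chunk_size=1000 and overlap=200 yield
    # four chunks, with consecutive chunks sharing roughly 200 characters.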

    def query(self, query_text, state=None, n_results=5):
        """Query the vector database for relevant tenant rights information."""
        state_filter = {"state": state} if state else None
        document_results = self.document_collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=state_filter
        )
        state_results = self.state_collection.query(
            query_texts=[query_text],
            n_results=n_results,
            where=state_filter
        )
        return {"document_results": document_results, "state_results": state_results}

    def get_states(self):
        """Get a list of all states in the database."""
        results = self.state_collection.get()
        states = {meta["state"] for meta in results["metadatas"] if meta}
        return sorted(states)
if __name__ == "__main__":
try:
db = VectorDatabase()
pdf_path = "tenant-landlord.pdf"
db.process_and_load_pdf(pdf_path)
states = db.get_states()
print(f"Available states: {states}")
    except Exception as e:
        logging.error(f"Script execution failed: {str(e)}")
        raise