import sys
import os
import boto3
import hashlib
import json
import logging
import threading

# Add the project root directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import List
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor, as_completed

import pinecone
from dotenv import load_dotenv
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from PyPDF2 import PdfReader
from tqdm.auto import tqdm

from config import get_settings

load_dotenv()


class RAGPrep:
    def __init__(self, processed_hashes_file="processed_hashes.json"):
        self.settings = get_settings()
        self.index_name = self.settings.INDEX_NAME
        self.pc = self.init_pinecone()
        self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.OPENAI_API_KEY)
        self.processed_hashes_file = processed_hashes_file
        self.processed_hashes = self.load_processed_hashes()

    def init_pinecone(self):
        """Initialize the Pinecone client."""
        return pinecone.Pinecone(api_key=self.settings.PINECONE_API_KEY)

    def create_or_connect_index(self, index_name, dimension):
        """Create a serverless Pinecone index if it doesn't exist, then connect to it."""
        spec = pinecone.ServerlessSpec(
            cloud=self.settings.CLOUD,
            region=self.settings.REGION
        )
        print(f"All indexes: {self.pc.list_indexes()}")
        if index_name not in self.pc.list_indexes().names():
            self.pc.create_index(
                name=index_name,
                dimension=dimension,
                metric="cosine",  # 'dotproduct' or other metrics also work if needed
                spec=spec
            )
        return self.pc.Index(index_name)

    def load_processed_hashes(self):
        """Load previously processed PDF hashes from a file."""
        if os.path.exists(self.processed_hashes_file):
            with open(self.processed_hashes_file, "r") as f:
                return set(json.load(f))
        return set()

    def save_processed_hashes(self):
        """Save processed PDF hashes to a file."""
        with open(self.processed_hashes_file, "w") as f:
            json.dump(list(self.processed_hashes), f)

    def generate_pdf_hash(self, pdf_content: bytes):
        """Generate an MD5 hash of the given PDF content (used for deduplication)."""
        hasher = hashlib.md5()
        hasher.update(pdf_content)
        return hasher.hexdigest()

    def load_and_split_pdfs(self, chunk_from=50, chunk_to=100) -> List[Document]:
        """Load PDFs from S3, extract text, and split into chunks.

        chunk_from/chunk_to slice the list of S3 keys, i.e. they select which
        PDFs are processed in this pass.
        """
        print("***********")
        # Initialize S3 client
        s3_client = boto3.client(
            's3',
            aws_access_key_id=self.settings.AWS_ACCESS_KEY,
            aws_secret_access_key=self.settings.AWS_SECRET_KEY,
            region_name=self.settings.AWS_REGION
        )

        # List objects in the S3 bucket (note: list_objects_v2 returns at most
        # 1,000 keys per call; add pagination if the bucket is larger)
        print(f"Listing files in S3 bucket: {self.settings.AWS_BUCKET_NAME}")
        response = s3_client.list_objects_v2(Bucket=self.settings.AWS_BUCKET_NAME, Prefix="")
        s3_keys = [obj['Key'] for obj in response.get('Contents', [])]
        print(f"Found {len(s3_keys)} objects in S3")

        documents = []

        # Process each PDF file in the selected slice
        for s3_key in s3_keys[chunk_from:chunk_to]:
            print(f"Processing file: {s3_key}")
            if not s3_key.lower().endswith(".pdf"):
                print("Not a PDF file, skipping.")
                continue
            try:
                # Read file from S3
                obj = s3_client.get_object(Bucket=self.settings.AWS_BUCKET_NAME, Key=s3_key)
                pdf_content = obj['Body'].read()

                # Generate hash and check for duplicates
                pdf_hash = self.generate_pdf_hash(pdf_content)
                if pdf_hash in self.processed_hashes:
                    print(f"Duplicate PDF detected: {s3_key}, skipping.")
                    continue

                # Extract text from PDF (extract_text() can return None for image-only pages)
                pdf_file = BytesIO(pdf_content)
                pdf_reader = PdfReader(pdf_file)
                text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

                # Add document with metadata
                documents.append(Document(page_content=text, metadata={"source": s3_key}))
                self.processed_hashes.add(pdf_hash)
            except Exception as e:
                print(f"Error processing {s3_key}: {e}")

        print(f"Extracted text from {len(documents)} documents")

        # Split documents into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.settings.CHUNK_SIZE,
            chunk_overlap=self.settings.CHUNK_OVERLAP
        )
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks")

        # Save updated hashes
        self.save_processed_hashes()

        return chunks

    def process_and_upload(self, total_batch=200):
        """Process PDFs in batches of 50 files and upload their chunks to Pinecone."""
        # Create or connect to index
        index = self.create_or_connect_index(self.index_name, self.settings.DIMENSIONS)

        # Load and split documents
        print("//////// chunking: ////////")
        current_batch = 0
        total_chunks = 0
        for _ in range(0, total_batch, 50):
            batch_size = 50  # Adjust based on your needs
            chunks = self.load_and_split_pdfs(current_batch, current_batch + batch_size)
            current_batch += batch_size

            # Prepare for batch processing
            max_threads = 4  # Adjust based on your hardware

            def process_batch(batch, batch_label):
                """Embed a single batch of chunks and build (id, embedding, metadata) tuples."""
                print(f"Processing batch {batch_label} on thread: {threading.current_thread().name}")
                print(f"Active threads: {threading.active_count()}")

                # Create ids for the batch
                ids = [f"chunk_{batch_label}_{j}" for j in range(len(batch))]

                # Get texts and generate embeddings
                texts = [doc.page_content for doc in batch]
                embeddings = self.embeddings.embed_documents(texts)

                # Create metadata
                metadata = [
                    {
                        "text": doc.page_content,
                        "source": doc.metadata.get("source", "unknown"),
                        "page": doc.metadata.get("page", 0)
                    }
                    for doc in batch
                ]

                # Create upsert batch
                return list(zip(ids, embeddings, metadata))

            with ThreadPoolExecutor(max_threads) as executor:
                futures = []
                print(f"Batch size being used: {batch_size}")
                for i in range(0, len(chunks), batch_size):
                    batch = chunks[i:i + batch_size]
                    # Prefix ids with current_batch so they stay unique across file batches
                    futures.append(executor.submit(process_batch, batch, f"{current_batch}_{i}"))

                # Gather results and upsert to Pinecone
                for future in tqdm(as_completed(futures), total=len(futures), desc="Uploading batches"):
                    try:
                        to_upsert = future.result()
                        index.upsert(vectors=to_upsert)
                    except Exception as e:
                        print(f"Error processing batch: {e}")

            total_chunks += len(chunks)

        print(f"Successfully processed and uploaded {total_chunks} chunks to Pinecone")

    def cleanup_index(self) -> bool:
        """
        Delete all vectors from the Pinecone index.

        Returns:
            bool: True if cleanup was successful, False otherwise
        """
        try:
            # Only attempt deletion if the index exists
            if self.index_name in self.pc.list_indexes().names():
                print(f"Index name found in {self.pc.list_indexes().names()}")
                index = self.pc.Index(self.index_name)
                index.delete(delete_all=True)
                print(f"Successfully cleaned up index: {self.index_name}")
                return True
            print("Index doesn't exist.")
            return True
        except Exception as e:
            print(f"Unexpected error during index cleanup: {e}")
            logging.error(f"Failed to cleanup index {self.index_name}. Error: {e}")
            return False
        finally:
            # Runs regardless of success or failure
            print("Cleanup operation completed.")


# Example usage:
if __name__ == "__main__":
    rag_prep = RAGPrep()
    rag_prep.process_and_upload()
    # rag_prep.cleanup_index()
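
# Example .env file content (a sketch only: the variable names below are
# assumptions inferred from the settings this script reads via
# config.get_settings(); adjust names and values to match your Settings class):
#
# OPENAI_API_KEY=sk-...
# PINECONE_API_KEY=...
# INDEX_NAME=my-rag-index
# DIMENSIONS=1536          # must match the embedding model's output dimension
# CLOUD=aws
# REGION=us-east-1
# AWS_ACCESS_KEY=...
# AWS_SECRET_KEY=...
# AWS_REGION=us-east-1
# AWS_BUCKET_NAME=my-pdf-bucket
# CHUNK_SIZE=1000
# CHUNK_OVERLAP=200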