"""PyLate document search (Gradio / ZeroGPU Space).

Uploaded PDF, DOCX and TXT files are converted to text, split into
overlapping character chunks, encoded with a PyLate ColBERT model and added
to a PLAID index. Per-chunk metadata (source file, offsets, chunk text) is
kept in a local SQLite database so search hits can be shown with context.
"""

import os

# Set before torch / spaces / pylate are imported so the values are already
# in the environment when those libraries initialise.
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import gradio as gr
import spaces
import torch
import tempfile
import sqlite3
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Any, Tuple

import docx
import fitz  # pymupdf
from unstructured.partition.auto import partition

# PyLate imports
from pylate import models, indexes, retrieve

# Global PyLate components; None until initialize_pylate() succeeds.
model = None
index = None
retriever = None
# sqlite3.Connection to metadata.db, opened by init_metadata_db().
metadata_db = None


# ===== DOCUMENT PROCESSING FUNCTIONS =====

def _read_pdf(file_path: str) -> str:
    """Return the stripped text of a PDF; raise if no extractor succeeds.

    PyMuPDF is tried first. If it raises or yields only whitespace,
    ``unstructured.partition`` is tried instead. When both fail, the PyMuPDF
    error is raised if PyMuPDF failed, otherwise the unstructured error.
    """
    text = ""
    fitz_error = None
    try:
        # The context manager closes the document even if get_text() raises.
        with fitz.open(file_path) as doc:
            text = "\n".join(page.get_text() for page in doc)
    except Exception as exc:
        fitz_error = exc

    if text.strip():
        return text.strip()

    try:
        elements = partition(filename=file_path)
    except Exception as exc:
        if fitz_error is not None:
            raise fitz_error from exc
        raise
    return "\n".join(str(element) for element in elements).strip()


def _read_docx(file_path: str) -> str:
    """Return the stripped paragraph text of a DOCX file (raises on failure)."""
    document = docx.Document(file_path)
    return "\n".join(paragraph.text for paragraph in document.paragraphs).strip()


def _read_txt(file_path: str) -> str:
    """Return the stripped contents of a text file (raises OSError on I/O failure).

    UTF-8 is tried first; ``utf-8-sig`` also drops a leading byte-order mark.
    On a decoding error the file is re-read as latin-1, which accepts every
    byte value.
    """
    path = Path(file_path)
    try:
        return path.read_text(encoding="utf-8-sig").strip()
    except UnicodeDecodeError:
        return path.read_text(encoding="latin1").strip()


def extract_text_from_pdf(file_path: str) -> str:
    """Extract text from PDF file using PyMuPDF and unstructured as fallback.

    Returns:
        The stripped text, or a message starting with ``"Error: "`` when
        extraction fails.
    """
    try:
        return _read_pdf(file_path)
    except Exception as e:
        return f"Error: Could not extract text from PDF: {str(e)}"


def extract_text_from_docx(file_path: str) -> str:
    """Extract paragraph text from a DOCX file.

    Returns:
        The stripped text, or a message starting with ``"Error: "`` on failure.
    """
    try:
        return _read_docx(file_path)
    except Exception as e:
        return f"Error: Could not extract text from DOCX: {str(e)}"


def extract_text_from_txt(file_path: str) -> str:
    """Read a TXT file as UTF-8, falling back to latin-1.

    Returns:
        The stripped text, or a message starting with ``"Error: "`` on failure.
    """
    try:
        return _read_txt(file_path)
    except Exception as e:
        return f"Error: Could not read text file: {str(e)}"


# Extension -> raising extractor. process_documents uses these instead of the
# public wrappers so a document whose text starts with "Error:" is not
# mistaken for a failed extraction.
_EXTRACTORS = {
    ".pdf": _read_pdf,
    ".docx": _read_docx,
    ".txt": _read_txt,
}


def chunk_text(text: str, chunk_size: int = 1000,
               overlap: int = 100) -> List[Dict[str, Any]]:
    """Split text into overlapping character windows, preferring sentence ends.

    Each window holds at most ``chunk_size`` characters. A window that is not
    the last one is cut back to its last '.' or newline if that boundary lies
    beyond 70% of ``chunk_size``. The next window starts ``overlap``
    characters before the previous one ended. The overlap is capped at half
    of the window just produced, so every step advances by at least half a
    window.

    Args:
        text: Text to split.
        chunk_size: Maximum characters per chunk; a non-positive value yields [].
        overlap: Characters shared by consecutive chunks; negatives count as 0.

    Returns:
        One dict per non-blank window with keys:
        - 'text': the stripped window;
        - 'start' / 'end': offsets of the raw window in ``text``, with 'end'
          never past ``len(text)``;
        - 'index': 0-based position in the returned list;
        - 'length': ``len`` of the stripped text.
    """
    chunk_size = int(chunk_size)
    overlap = max(0, int(overlap))
    chunks: List[Dict[str, Any]] = []
    if chunk_size <= 0:
        return chunks

    text_len = len(text)
    start = 0
    while start < text_len:
        end = min(start + chunk_size, text_len)
        window = text[start:end]

        # Not the final window: end on a sentence/line boundary if doing so
        # keeps more than 70% of the window.
        if end < text_len:
            break_point = max(window.rfind('.'), window.rfind('\n'))
            if break_point > chunk_size * 0.7:
                window = window[:break_point + 1]
                end = start + break_point + 1

        stripped = window.strip()
        if stripped:
            chunks.append({
                'text': stripped,
                'start': start,
                'end': end,
                'index': len(chunks),
                'length': len(stripped),
            })

        if end >= text_len:
            # This window reached the end of the text; another window would
            # only repeat the tail already covered here.
            break
        # Capping the overlap guarantees real progress. Without the cap, a
        # large overlap could advance the window by a single character.
        start = end - min(overlap, (end - start) // 2)

    return chunks


# ===== METADATA DATABASE =====

def init_metadata_db():
    """Open the SQLite metadata database (once) and ensure its schema exists.

    The connection is stored in the module-global ``metadata_db`` and is
    reused on later calls, so re-initialising does not leak connections.
    ``check_same_thread=False`` allows the connection to be used from threads
    other than the one that created it.
    """
    global metadata_db
    if metadata_db is None:
        metadata_db = sqlite3.connect("metadata.db", check_same_thread=False)
    metadata_db.execute("""
        CREATE TABLE IF NOT EXISTS documents (
            doc_id TEXT PRIMARY KEY,
            filename TEXT NOT NULL,
            file_hash TEXT NOT NULL,
            original_text TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            total_chunks INTEGER NOT NULL,
            chunk_start INTEGER NOT NULL,
            chunk_end INTEGER NOT NULL,
            chunk_size INTEGER NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    metadata_db.execute(
        "CREATE INDEX IF NOT EXISTS idx_filename ON documents(filename)"
    )
    metadata_db.commit()


def add_document_metadata(doc_id: str, filename: str, file_hash: str,
                          original_text: str, chunk_info: Dict[str, Any],
                          total_chunks: int, *, commit: bool = True):
    """Insert or replace the metadata row for one indexed chunk.

    Args:
        doc_id: ID under which the chunk was added to the PLAID index.
        filename: Base name of the source file.
        file_hash: MD5 hex digest of the source file's bytes.
        original_text: Chunk text shown in search results.
        chunk_info: A dict from ``chunk_text``. Its 'index', 'start', 'end'
            and 'length' are stored; 'length' goes into the ``chunk_size``
            column.
        total_chunks: Number of chunks the source file was split into.
        commit: Commit right away. Pass False when inserting a batch, then
            commit once.
    """
    metadata_db.execute(
        """
        INSERT OR REPLACE INTO documents
        (doc_id, filename, file_hash, original_text, chunk_index,
         total_chunks, chunk_start, chunk_end, chunk_size)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            doc_id,
            filename,
            file_hash,
            original_text,
            chunk_info['index'],
            total_chunks,
            chunk_info['start'],
            chunk_info['end'],
            chunk_info['length'],
        ),
    )
    if commit:
        metadata_db.commit()


def get_document_metadata(doc_id: str) -> Dict[str, Any]:
    """Return the metadata row for ``doc_id`` as a column->value dict.

    Returns {} when the row does not exist or the database is not open.
    """
    if metadata_db is None:
        return {}
    cursor = metadata_db.execute(
        "SELECT * FROM documents WHERE doc_id = ?", (doc_id,)
    )
    row = cursor.fetchone()
    if row is None:
        return {}
    columns = [desc[0] for desc in cursor.description]
    return dict(zip(columns, row))


# ===== PYLATE INITIALIZATION =====

@spaces.GPU
def initialize_pylate(model_name: str = "colbert-ir/colbertv2.0") -> str:
    """Load the ColBERT model and create the PLAID index and retriever.

    New components are built in locals and assigned to the module globals
    only after all of them were created. A failure therefore leaves the
    globals as they were, not a mix of new and old objects.

    Returns:
        A status message. Errors are reported in the message, not raised.
    """
    global model, index, retriever
    try:
        init_metadata_db()

        new_model = models.ColBERT(model_name_or_path=model_name)
        if torch.cuda.is_available():
            new_model = new_model.to('cuda')

        # override=True asks PyLate to replace an existing index with this
        # name. metadata.db is not cleared here.
        new_index = indexes.PLAID(
            index_folder="./pylate_index",
            index_name="documents",
            override=True,
            kmeans_niters=1,  # Reduce k-means iterations
            nbits=1,  # Reduce quantization bits
        )
        new_retriever = retrieve.ColBERT(index=new_index)
    except Exception as e:
        return f"❌ Error initializing PyLate: {str(e)}"

    model, index, retriever = new_model, new_index, new_retriever
    device = 'GPU' if torch.cuda.is_available() else 'CPU'
    return (
        f"✅ PyLate initialized successfully!\nModel: {model_name}\n"
        f"Device: {device}"
    )


# ===== DOCUMENT PROCESSING =====

def _upload_path(file: Any) -> str:
    """Return the filesystem path of one Gradio upload.

    Depending on the Gradio version and ``gr.File`` settings, an upload
    arrives either as a path string or as a tempfile-like object whose
    ``.name`` is the path. Both forms are accepted.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


@spaces.GPU
def process_documents(files, chunk_size: int = 1000, overlap: int = 100) -> str:
    """Extract, chunk, encode and index the uploaded files.

    Each file is handled independently. An unsupported, unreadable or
    duplicate file is reported in the status text and skipped, and the rest
    of the batch still goes ahead.

    Args:
        files: Uploads from ``gr.File``; ``None`` or empty means nothing was
            uploaded.
        chunk_size: Maximum characters per chunk (coerced to ``int``).
        overlap: Characters shared between consecutive chunks (coerced to
            ``int``).

    Returns:
        A status report. Errors are reported in the string, not raised.
    """
    if model is None or index is None:
        return "❌ Please initialize PyLate first!"
    if not files:
        return "❌ No files uploaded!"

    try:
        chunk_size = int(chunk_size)
        overlap = int(overlap)

        all_documents: List[str] = []
        all_doc_ids: List[str] = []
        # Keyword arguments for add_document_metadata, written after indexing.
        pending_metadata: List[Dict[str, Any]] = []
        processed_files: List[str] = []
        seen_hashes = set()
        indexed_files = 0

        for file in files:
            file_path = _upload_path(file)
            filename = Path(file_path).name

            extractor = _EXTRACTORS.get(Path(filename).suffix.lower())
            if extractor is None:
                processed_files.append(f"{filename}: Unsupported file type, skipped")
                continue

            try:
                file_hash = hashlib.md5(Path(file_path).read_bytes()).hexdigest()
            except OSError as e:
                processed_files.append(f"{filename}: Could not read file ({e})")
                continue
            if file_hash in seen_hashes:
                processed_files.append(
                    f"{filename}: Same content as an earlier file in this upload, skipped"
                )
                continue
            seen_hashes.add(file_hash)

            try:
                text = extractor(file_path)
            except Exception as e:
                processed_files.append(f"{filename}: Failed to extract text ({e})")
                continue
            if not text:
                processed_files.append(f"{filename}: Failed to extract text")
                continue

            chunks = chunk_text(text, chunk_size, overlap)
            for chunk in chunks:
                # The content hash keeps IDs distinct when different files
                # share a name. Otherwise a later file would overwrite an
                # earlier file's metadata while both embeddings stay indexed.
                doc_id = f"{filename}_{file_hash[:12]}_chunk_{chunk['index']}"
                all_documents.append(chunk['text'])
                all_doc_ids.append(doc_id)
                pending_metadata.append({
                    'doc_id': doc_id,
                    'filename': filename,
                    'file_hash': file_hash,
                    'original_text': chunk['text'],
                    'chunk_info': chunk,
                    'total_chunks': len(chunks),
                })

            processed_files.append(f"{filename}: {len(chunks)} chunks")
            indexed_files += 1

        if not all_documents:
            return "❌ No text could be extracted from uploaded files!"

        # Encode documents with PyLate
        document_embeddings = model.encode(
            all_documents,
            batch_size=16,  # Smaller batch for ZeroGPU
            is_query=False,
            show_progress_bar=True,
        )

        # Add to PLAID index
        index.add_documents(
            documents_ids=all_doc_ids,
            documents_embeddings=document_embeddings,
        )

        # Metadata is stored only after the chunks are in the index. A failed
        # encode/index step then leaves no rows for unretrievable chunks. The
        # batch is committed once rather than once per chunk.
        for row in pending_metadata:
            add_document_metadata(**row, commit=False)
        metadata_db.commit()

        lines = [
            f"✅ Successfully processed {indexed_files} of {len(files)} files:",
            f"📄 Total chunks: {len(all_documents)}",
            "🔍 Indexed documents:",
        ]
        lines.extend(f"  • {file_info}" for file_info in processed_files)
        return "\n".join(lines) + "\n"

    except Exception as e:
        return f"❌ Error processing documents: {str(e)}"


# ===== SEARCH FUNCTION =====

@spaces.GPU
def search_documents(query: str, k: int = 5, show_chunks: bool = True) -> str:
    """Search the index and format the top ``k`` hits with their metadata.

    Args:
        query: Natural-language query; ``None`` or blank is rejected.
        k: Number of results to return (coerced to ``int``).
        show_chunks: Include up to 300 characters of each chunk's text.

    Returns:
        Markdown-formatted results or a status/error message.
    """
    if model is None or retriever is None:
        return "❌ Please initialize PyLate and process documents first!"
    if not query or not query.strip():
        return "❌ Please enter a search query!"

    try:
        # Encode query
        query_embedding = model.encode([query], is_query=True)

        # retrieve() returns one result list per query; only one was sent.
        results = retriever.retrieve(query_embedding, k=int(k))[0]
        if not results:
            return "🔍 No results found for your query."

        formatted_results = [f"🔍 **Search Results for:** '{query}'\n"]
        for rank, result in enumerate(results, start=1):
            metadata = get_document_metadata(result['id'])

            formatted_results.append(f"## Result {rank} (Score: {result['score']:.2f})")
            formatted_results.append(
                f"**File:** {metadata.get('filename', 'Unknown')}")
            formatted_results.append(
                f"**Chunk:** {metadata.get('chunk_index', 0) + 1}/"
                f"{metadata.get('total_chunks', 1)}")

            if show_chunks:
                text = metadata.get('original_text', '')
                preview = text[:300] + "..." if len(text) > 300 else text
                formatted_results.append(f"**Text:** {preview}")

            formatted_results.append("---")

        return "\n".join(formatted_results)

    except Exception as e:
        return f"❌ Error searching: {str(e)}"


# ===== GRADIO INTERFACE =====

def create_interface():
    """Build the Gradio Blocks UI with Setup, Upload, Search and Info tabs."""
    with gr.Blocks(title="PyLate Document Search", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🔍 PyLate Document Search
        ### Powered by ColBERT and ZeroGPU

        Upload documents, process them with PyLate, and perform semantic search!

        **Note:** Using PyMuPDF and Unstructured for robust PDF text extraction.
        """)

        with gr.Tab("🚀 Setup"):
            gr.Markdown("### Initialize PyLate System")
            model_choice = gr.Dropdown(
                choices=[
                    "colbert-ir/colbertv2.0",
                    "sentence-transformers/all-MiniLM-L6-v2",
                ],
                value="colbert-ir/colbertv2.0",
                label="Select Model",
            )
            init_btn = gr.Button("Initialize PyLate", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", lines=3)

            init_btn.click(
                initialize_pylate,
                inputs=model_choice,
                outputs=init_status,
            )

        with gr.Tab("📄 Document Upload"):
            gr.Markdown("### Upload and Process Documents")

            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"],
                        label="Upload Documents (PDF, DOCX, TXT)",
                    )

                    with gr.Row():
                        chunk_size = gr.Slider(
                            minimum=500,
                            maximum=3000,
                            value=1000,
                            step=100,
                            label="Chunk Size (characters)",
                        )
                        overlap = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=100,
                            step=50,
                            label="Chunk Overlap (characters)",
                        )

                    process_btn = gr.Button(
                        "Process Documents", variant="primary")

                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=10,
                        max_lines=15,
                    )

            process_btn.click(
                process_documents,
                inputs=[file_upload, chunk_size, overlap],
                outputs=process_status,
            )

        with gr.Tab("🔍 Search"):
            gr.Markdown("### Search Your Documents")

            with gr.Row():
                with gr.Column():
                    search_query = gr.Textbox(
                        label="Search Query",
                        placeholder="Enter your search query...",
                        lines=2,
                    )

                    with gr.Row():
                        num_results = gr.Slider(
                            minimum=1,
                            maximum=20,
                            value=5,
                            step=1,
                            label="Number of Results",
                        )
                        show_chunks = gr.Checkbox(
                            value=True,
                            label="Show Text Chunks",
                        )

                    search_btn = gr.Button("Search", variant="primary")

                with gr.Column():
                    search_results = gr.Textbox(
                        label="Search Results",
                        lines=15,
                        max_lines=20,
                    )

            search_btn.click(
                search_documents,
                inputs=[search_query, num_results, show_chunks],
                outputs=search_results,
            )

        with gr.Tab("ℹ️ Info"):
            gr.Markdown("""
            ### About This System

            **PyLate Document Search** is a semantic search system that uses:

            - **PyLate**: A flexible library for ColBERT models
            - **ColBERT**: Late interaction retrieval for high-quality search
            - **ZeroGPU**: Hugging Face's free GPU infrastructure

            #### Features:
            - 📄 Multi-format document support (PDF, DOCX, TXT)
            - ✂️ Intelligent text chunking with overlap
            - 🧠 Semantic search using ColBERT embeddings
            - 💾 Metadata tracking for result context
            - ⚡ GPU-accelerated processing

            #### PDF Processing:
            - Uses PyMuPDF (fitz) for reliable text extraction
            - Falls back to Unstructured for complex PDFs
            - No dependency on PyPDF2

            #### Usage Tips:
            1. Initialize the system first (required)
            2. Upload your documents and process them
            3. Use natural language queries for best results
            4. Adjust chunk size based on your document types

            Built with ❤️ using PyLate and Gradio
            """)

    return demo


# ===== MAIN =====

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )