import gradio as gr
import spaces
import torch
import os
import tempfile
import sqlite3
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Any, Tuple
import PyPDF2
import docx
import fitz  # pymupdf
from unstructured.partition.auto import partition

# Environment tweaks are applied before the PyLate import below; keep this
# ordering so they are already in effect when PyLate is imported.
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

# PyLate imports
from pylate import models, indexes, retrieve

# Global variables for PyLate components (assigned by initialize_pylate /
# init_metadata_db, read by process_documents / search_documents).
model = None
index = None
retriever = None
metadata_db = None


# ===== DOCUMENT PROCESSING FUNCTIONS =====

def extract_text_from_pdf(file_path: str) -> str:
    """Extract text from a PDF file, trying several backends in turn.

    PyMuPDF is tried first, then PyPDF2, then ``unstructured``. Each backend
    builds its own result from scratch, so a backend that fails part-way
    cannot leave partial text behind for the next backend to duplicate.

    Args:
        file_path: Path to the PDF on disk.

    Returns:
        The extracted text with surrounding whitespace stripped, or
        ``"Error: Could not extract text from PDF"`` if every backend failed.
    """
    try:
        # Try PyMuPDF first (better for complex PDFs); the context manager
        # closes the document even if text extraction raises.
        with fitz.open(file_path) as doc:
            return "".join(page.get_text() + "\n" for page in doc).strip()
    except Exception:
        pass  # deliberate best-effort: fall through to the next backend

    try:
        # Fallback to PyPDF2
        with open(file_path, "rb") as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # extract_text() may yield None for pages without a text layer.
            text = "".join((page.extract_text() or "") + "\n" for page in pdf_reader.pages)
        return text.strip()
    except Exception:
        pass  # deliberate best-effort: fall through to the last backend

    try:
        # Last resort: unstructured
        elements = partition(filename=file_path)
        return "\n".join(str(element) for element in elements).strip()
    except Exception:
        return "Error: Could not extract text from PDF"


def extract_text_from_docx(file_path: str) -> str:
    """Extract paragraph text from a DOCX file.

    Args:
        file_path: Path to the .docx file on disk.

    Returns:
        Paragraph texts joined by newlines and stripped, or
        ``"Error: Could not extract text from DOCX"`` on any failure.
    """
    try:
        doc = docx.Document(file_path)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
    except Exception:
        return "Error: Could not extract text from DOCX"


def extract_text_from_txt(file_path: str) -> str:
    """Read a plain-text file, trying UTF-8 first and Latin-1 as a fallback.

    ``utf-8-sig`` decodes ordinary UTF-8 identically but also drops a leading
    byte-order mark, which would otherwise end up in the indexed text.

    Args:
        file_path: Path to the text file on disk.

    Returns:
        The stripped file contents, or ``"Error: Could not read text file"``
        if the file could not be read with any of the encodings.
    """
    for encoding in ("utf-8-sig", "latin1"):
        try:
            with open(file_path, "r", encoding=encoding) as file:
                return file.read().strip()
        except Exception:
            continue  # try the next encoding
    return "Error: Could not read text file"


def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[Dict[str, Any]]:
    """Split ``text`` into overlapping chunks, preferring sentence/line breaks.

    Each window is ``chunk_size`` characters. If a ``.`` or newline occurs
    after 70% of the window, the chunk is cut just after it. The next chunk
    starts ``overlap`` characters before the end of the current one; when
    that would not move forward (overlap >= chunk length) it starts at the
    end of the current chunk instead, so the loop always makes progress.

    Args:
        text: The full document text.
        chunk_size: Maximum chunk length in characters (must be positive).
        overlap: Characters shared between consecutive chunks (negative
            values are treated as 0).

    Returns:
        A list of dicts with keys ``text`` (stripped chunk), ``start``/``end``
        (character offsets into ``text``), ``index`` (0-based position among
        the returned chunks) and ``length`` (length of the stripped text).

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    # Coerce to int so float values (e.g. from UI sliders) can be used for slicing.
    chunk_size = int(chunk_size)
    overlap = max(0, int(overlap))
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    chunks: List[Dict[str, Any]] = []
    text_length = len(text)
    start = 0

    while start < text_length:
        end = min(start + chunk_size, text_length)
        piece = text[start:end]

        # Try to break at a sentence boundary, but only when that keeps the
        # chunk longer than 70% of chunk_size.
        if end < text_length:
            break_point = max(piece.rfind("."), piece.rfind("\n"))
            if break_point > chunk_size * 0.7:
                piece = piece[:break_point + 1]
                end = start + break_point + 1

        stripped = piece.strip()
        if stripped:
            chunks.append({
                "text": stripped,
                "start": start,
                "end": end,
                "index": len(chunks),
                "length": len(stripped),
            })

        if end >= text_length:
            # This chunk already reaches the end of the text; starting another
            # one inside the overlap would only produce a duplicate tail.
            break

        next_start = end - overlap
        start = next_start if next_start > start else end

    return chunks


# ===== METADATA DATABASE =====

def init_metadata_db(db_path: str = "metadata.db") -> None:
    """Open the SQLite metadata store and create its schema if missing.

    Any connection opened by a previous call is closed first, so calling this
    repeatedly (e.g. re-initializing from the UI) does not leak connections.

    Args:
        db_path: Location of the SQLite database file.
    """
    global metadata_db
    if metadata_db is not None:
        metadata_db.close()
    metadata_db = sqlite3.connect(db_path, check_same_thread=False)
    with metadata_db:  # commits on success, rolls back on error
        metadata_db.execute("""
            CREATE TABLE IF NOT EXISTS documents (
                doc_id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                file_hash TEXT NOT NULL,
                original_text TEXT NOT NULL,
                chunk_index INTEGER NOT NULL,
                total_chunks INTEGER NOT NULL,
                chunk_start INTEGER NOT NULL,
                chunk_end INTEGER NOT NULL,
                chunk_size INTEGER NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        metadata_db.execute("""
            CREATE INDEX IF NOT EXISTS idx_filename ON documents(filename);
        """)


def add_document_metadata(doc_id: str, filename: str, file_hash: str,
                          original_text: str, chunk_info: Dict[str, Any],
                          total_chunks: int):
    """Insert or replace the metadata row for one chunk.

    Args:
        doc_id: Identifier also used for the chunk in the PLAID index.
        filename: Base name of the source file.
        file_hash: Hex MD5 digest of the source file.
        original_text: The chunk text.
        chunk_info: A dict produced by ``chunk_text`` (uses ``index``,
            ``start``, ``end`` and ``length``).
        total_chunks: Number of chunks produced for the source file.
    """
    with metadata_db:  # commits the single-row write
        metadata_db.execute("""
            INSERT OR REPLACE INTO documents
            (doc_id, filename, file_hash, original_text, chunk_index,
             total_chunks, chunk_start, chunk_end, chunk_size)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            doc_id, filename, file_hash, original_text,
            chunk_info["index"], total_chunks, chunk_info["start"],
            chunk_info["end"], chunk_info["length"],
        ))


def get_document_metadata(doc_id: str) -> Dict[str, Any]:
    """Return the metadata row for ``doc_id`` as a column->value dict.

    Returns an empty dict when no row exists.
    """
    cursor = metadata_db.execute(
        "SELECT * FROM documents WHERE doc_id = ?", (doc_id,)
    )
    row = cursor.fetchone()
    if row is None:
        return {}
    columns = [desc[0] for desc in cursor.description]
    return dict(zip(columns, row))


# ===== PYLATE INITIALIZATION =====

@spaces.GPU
def initialize_pylate(model_name: str = "lightonai/GTE-ModernColBERT-v1") -> str:
    """Load the ColBERT model, create a fresh PLAID index and a retriever.

    The module globals are only replaced once every component has been built,
    so a failure part-way through does not leave a mismatched model/index
    pair behind.

    NOTE(review): on ZeroGPU, ``@spaces.GPU`` functions may execute in a
    separate worker process; confirm that the globals assigned here are still
    visible to later calls of process_documents / search_documents.

    Returns:
        A human-readable status message (success or the error text).
    """
    global model, index, retriever

    try:
        # Initialize metadata database
        init_metadata_db()

        # Load ColBERT model and move it to the GPU when one is available.
        new_model = models.ColBERT(model_name_or_path=model_name)
        if torch.cuda.is_available():
            new_model = new_model.to("cuda")

        # Initialize the PLAID index; override=True discards any index that
        # already exists under ./pylate_index.
        new_index = indexes.PLAID(
            index_folder="./pylate_index",
            index_name="documents",
            override=True,
            kmeans_niters=1,  # Reduce k-means iterations
            nbits=1,  # Reduce quantization bits
        )

        new_retriever = retrieve.ColBERT(index=new_index)
    except Exception as e:
        return f"❌ Error initializing PyLate: {str(e)}"

    model, index, retriever = new_model, new_index, new_retriever
    device = "GPU" if torch.cuda.is_available() else "CPU"
    return f"✅ PyLate initialized successfully!\nModel: {model_name}\nDevice: {device}"


# ===== DOCUMENT PROCESSING =====

def _file_md5(file_path: str, block_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, reading it in fixed-size blocks."""
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()


def _extract_text(filename: str, file_path: str):
    """Run the extractor matching the file extension; None if unsupported."""
    lower_name = filename.lower()
    if lower_name.endswith(".pdf"):
        return extract_text_from_pdf(file_path)
    if lower_name.endswith(".docx"):
        return extract_text_from_docx(file_path)
    if lower_name.endswith(".txt"):
        return extract_text_from_txt(file_path)
    return None


@spaces.GPU
def process_documents(files, chunk_size: int = 1000, overlap: int = 100) -> str:
    """Extract, chunk, embed and index uploaded documents.

    Args:
        files: Uploaded files from Gradio, either path strings or file-like
            objects exposing ``.name``.
        chunk_size: Chunk length in characters, passed to ``chunk_text``.
        overlap: Chunk overlap in characters, passed to ``chunk_text``.

    Returns:
        A human-readable status message.
    """
    if model is None or index is None:
        return "❌ Please initialize PyLate first!"

    if not files:
        return "❌ No files uploaded!"

    try:
        all_documents: List[str] = []
        all_doc_ids: List[str] = []
        # Metadata rows are written only after indexing succeeds, so the
        # database never describes chunks that did not reach the index.
        pending_metadata: List[Dict[str, Any]] = []
        processed_files: List[str] = []
        skipped_files: List[str] = []

        for file in files:
            # Accept both plain path strings and objects with a .name path.
            file_path = file if isinstance(file, str) else file.name
            filename = Path(file_path).name

            text = _extract_text(filename, file_path)
            # The extract_text_* helpers signal failure with an "Error:" string.
            if not text or text.startswith("Error:"):
                skipped_files.append(filename)
                continue

            file_hash = _file_md5(file_path)
            chunks = chunk_text(text, chunk_size, overlap)

            for chunk in chunks:
                doc_id = f"{filename}_chunk_{chunk['index']}"
                all_documents.append(chunk["text"])
                all_doc_ids.append(doc_id)
                pending_metadata.append({
                    "doc_id": doc_id,
                    "filename": filename,
                    "file_hash": file_hash,
                    "original_text": chunk["text"],
                    "chunk_info": chunk,
                    "total_chunks": len(chunks),
                })

            processed_files.append(f"{filename}: {len(chunks)} chunks")

        if not all_documents:
            return "❌ No text could be extracted from uploaded files!"

        # Encode documents with PyLate
        document_embeddings = model.encode(
            all_documents,
            batch_size=16,  # Smaller batch for ZeroGPU
            is_query=False,
            show_progress_bar=True,
        )

        # Add to PLAID index
        index.add_documents(
            documents_ids=all_doc_ids,
            documents_embeddings=document_embeddings,
        )

        for row in pending_metadata:
            add_document_metadata(**row)

        result = f"✅ Successfully processed {len(processed_files)} files:\n"
        result += f"📄 Total chunks: {len(all_documents)}\n"
        result += "🔍 Indexed documents:\n"
        result += "".join(f"  • {file_info}\n" for file_info in processed_files)
        if skipped_files:
            result += (
                f"⚠️ Skipped {len(skipped_files)} file(s) (unsupported or unreadable): "
                f"{', '.join(skipped_files)}\n"
            )
        return result

    except Exception as e:
        return f"❌ Error processing documents: {str(e)}"


# ===== SEARCH FUNCTION =====

@spaces.GPU
def search_documents(query: str, k: int = 5, show_chunks: bool = True) -> str:
    """Search the index and format the top-``k`` hits with their metadata.

    Args:
        query: Free-text search query.
        k: Number of results to return (coerced to int).
        show_chunks: Whether to include a 300-character preview of each chunk.

    Returns:
        A Markdown-formatted result listing or a status/error message.
    """
    if model is None or retriever is None:
        return "❌ Please initialize PyLate and process documents first!"

    if not query or not query.strip():
        return "❌ Please enter a search query!"

    try:
        # Encode query
        query_embedding = model.encode([query], is_query=True)

        # Search; results for the single query are at position 0.
        results = retriever.retrieve(query_embedding, k=int(k))[0]

        if not results:
            return "🔍 No results found for your query."

        # Format results with metadata
        formatted_results = [f"🔍 **Search Results for:** '{query}'\n"]

        for i, result in enumerate(results):
            doc_id = result["id"]
            score = result["score"]

            metadata = get_document_metadata(doc_id)

            formatted_results.append(f"## Result {i+1} (Score: {score:.2f})")
            formatted_results.append(
                f"**File:** {metadata.get('filename', 'Unknown')}")
            formatted_results.append(
                f"**Chunk:** {metadata.get('chunk_index', 0) + 1}/{metadata.get('total_chunks', 1)}")

            if show_chunks:
                text = metadata.get("original_text", "")
                preview = text[:300] + "..." if len(text) > 300 else text
                formatted_results.append(f"**Text:** {preview}")

            formatted_results.append("---")

        return "\n".join(formatted_results)

    except Exception as e:
        return f"❌ Error searching: {str(e)}"


# ===== GRADIO INTERFACE =====

def create_interface():
    """Build the Gradio Blocks UI with setup, upload, search and info tabs."""
    with gr.Blocks(title="PyLate Document Search", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
            # 🔍 PyLate Document Search
            ### Powered by ColBERT and ZeroGPU H100

            Upload documents, process them with PyLate, and perform semantic search!
        """)

        with gr.Tab("🚀 Setup"):
            gr.Markdown("### Initialize PyLate System")

            model_choice = gr.Dropdown(
                # The default value must be one of the choices; it was
                # previously commented out of this list.
                choices=[
                    "lightonai/GTE-ModernColBERT-v1",
                    "colbert-ir/colbertv2.0",
                    "sentence-transformers/all-MiniLM-L6-v2",
                ],
                value="lightonai/GTE-ModernColBERT-v1",
                label="Select Model",
            )

            init_btn = gr.Button("Initialize PyLate", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", lines=3)

            init_btn.click(
                initialize_pylate,
                inputs=model_choice,
                outputs=init_status,
            )

        with gr.Tab("📄 Document Upload"):
            gr.Markdown("### Upload and Process Documents")

            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"],
                        label="Upload Documents (PDF, DOCX, TXT)",
                    )

                    with gr.Row():
                        chunk_size = gr.Slider(
                            minimum=500,
                            maximum=3000,
                            value=1000,
                            step=100,
                            label="Chunk Size (characters)",
                        )
                        overlap = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=100,
                            step=50,
                            label="Chunk Overlap (characters)",
                        )

                    process_btn = gr.Button(
                        "Process Documents", variant="primary")

                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=10,
                        max_lines=15,
                    )

            process_btn.click(
                process_documents,
                inputs=[file_upload, chunk_size, overlap],
                outputs=process_status,
            )

        with gr.Tab("🔍 Search"):
            gr.Markdown("### Search Your Documents")

            with gr.Row():
                with gr.Column():
                    search_query = gr.Textbox(
                        label="Search Query",
                        placeholder="Enter your search query...",
                        lines=2,
                    )

                    with gr.Row():
                        num_results = gr.Slider(
                            minimum=1,
                            maximum=20,
                            value=5,
                            step=1,
                            label="Number of Results",
                        )
                        show_chunks = gr.Checkbox(
                            value=True,
                            label="Show Text Chunks",
                        )

                    search_btn = gr.Button("Search", variant="primary")

                with gr.Column():
                    search_results = gr.Textbox(
                        label="Search Results",
                        lines=15,
                        max_lines=20,
                    )

            search_btn.click(
                search_documents,
                inputs=[search_query, num_results, show_chunks],
                outputs=search_results,
            )

        with gr.Tab("ℹ️ Info"):
            gr.Markdown("""
                ### About This System

                **PyLate Document Search** is a semantic search system that uses:
                - **PyLate**: A flexible library for ColBERT models
                - **ColBERT**: Late interaction retrieval for high-quality search
                - **ZeroGPU**: Hugging Face's free H100 GPU infrastructure

                #### Features:
                - 📄 Multi-format document support (PDF, DOCX, TXT)
                - ✂️ Intelligent text chunking with overlap
                - 🧠 Semantic search using ColBERT embeddings
                - 💾 Metadata tracking for result context
                - ⚡ GPU-accelerated processing

                #### Usage Tips:
                1. Initialize the system first (required)
                2. Upload your documents and process them
                3. Use natural language queries for best results
                4. Adjust chunk size based on your document types

                #### Model Information:
                - **GTE-ModernColBERT**: Latest high-performance model
                - **ColBERTv2**: Original Stanford implementation
                - **MiniLM**: Faster, smaller model for quick testing

                Built with ❤️ using PyLate and Gradio
            """)

    return demo


# ===== MAIN =====

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )