fullstack's picture
.
8c651f8
raw
history blame
16.1 kB
import gradio as gr
import spaces
import torch
import os
import tempfile
import sqlite3
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Any, Tuple
import docx
import fitz # pymupdf
from unstructured.partition.auto import partition
# Environment overrides applied at import time, before the PyLate import that
# follows. TRITON_CACHE_DIR is pointed at a directory under /tmp, and
# TORCH_COMPILE_DISABLE is set to "1" (presumably to switch off torch.compile).
# NOTE(review): `torch` is already imported above this point -- confirm both
# variables are still read late enough to take effect.
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["TORCH_COMPILE_DISABLE"] = "1"
# PyLate imports
from pylate import models, indexes, retrieve
# Global variables for PyLate components
# All four start as None and are populated at runtime:
#   model       - ColBERT model created by initialize_pylate()
#   index       - PLAID index created by initialize_pylate()
#   retriever   - ColBERT retriever built on top of `index`
#   metadata_db - sqlite3 connection opened by init_metadata_db()
model = None
index = None
retriever = None
metadata_db = None
# ===== DOCUMENT PROCESSING FUNCTIONS =====
def extract_text_from_pdf(file_path: str) -> str:
    """Extract text from a PDF, using PyMuPDF first and unstructured as fallback.

    PyMuPDF (``fitz``) is tried first. If it raises, or yields only
    whitespace, ``unstructured.partition`` is tried once.

    Args:
        file_path: Path of the PDF file on disk.

    Returns:
        The extracted text with surrounding whitespace stripped, or a string
        starting with ``"Error:"`` when neither extractor succeeds. This
        function does not raise for extraction failures; callers detect
        failure through the ``"Error:"`` prefix.
    """
    primary_error = None
    text = ""
    try:
        doc = fitz.open(file_path)
        try:
            text = "\n".join(page.get_text() for page in doc)
        finally:
            # Always release the document handle, even if a page fails.
            doc.close()
    except Exception as exc:
        # Remember why PyMuPDF failed so the final message can report it.
        primary_error = exc
        text = ""

    if text.strip():
        return text.strip()

    # PyMuPDF failed or found no text layer: fall back to unstructured.
    try:
        elements = partition(filename=file_path)
        return "\n".join(str(element) for element in elements).strip()
    except Exception as exc:
        reason = primary_error if primary_error is not None else exc
        return f"Error: Could not extract text from PDF: {str(reason)}".strip()
def extract_text_from_docx(file_path: str) -> str:
    """Return the paragraph text of a DOCX file, one paragraph per line.

    Args:
        file_path: Path of the DOCX file on disk.

    Returns:
        The paragraphs joined by newlines with surrounding whitespace
        stripped, or a string starting with ``"Error:"`` if the document
        could not be read.
    """
    try:
        paragraphs = docx.Document(file_path).paragraphs
        lines = [para.text + "\n" for para in paragraphs]
        return "".join(lines).strip()
    except Exception as err:
        return f"Error: Could not extract text from DOCX: {str(err)}"
def extract_text_from_txt(file_path: str) -> str:
    """Read a plain-text file as UTF-8, falling back to Latin-1.

    The Latin-1 retry happens only when the file is not valid UTF-8. Other
    failures (missing file, permission denied, ...) are reported directly
    instead of triggering a pointless second read.

    Args:
        file_path: Path of the text file on disk.

    Returns:
        The file content with surrounding whitespace stripped, or a string
        starting with ``"Error:"`` if the file could not be read.
    """
    try:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read().strip()
        except UnicodeDecodeError:
            # Latin-1 maps every byte to a character, so this cannot fail
            # on decoding; it may still fail on I/O.
            with open(file_path, 'r', encoding='latin1') as file:
                return file.read().strip()
    except Exception as e:
        return f"Error: Could not read text file: {str(e)}"
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[Dict[str, Any]]:
    """Split text into overlapping chunks and return them with metadata.

    Each chunk spans at most ``chunk_size`` characters. A chunk that is not
    the last one is shortened to end just after its last ``'.'`` or newline,
    provided that boundary lies beyond 70% of ``chunk_size``. The next chunk
    starts ``overlap`` characters before the previous chunk's end, and always
    at least one character further than the previous start.

    Note: when ``overlap`` is close to or larger than ``chunk_size`` the
    window advances only one character at a time, producing very many chunks.

    Args:
        text: The text to split.
        chunk_size: Maximum chunk length in characters; must be positive.
        overlap: Number of characters shared by consecutive chunks; must not
            be negative.

    Returns:
        A list of dicts with keys ``text`` (stripped chunk), ``start`` and
        ``end`` (offsets into ``text``, with ``end`` never exceeding
        ``len(text)``), ``index`` (0-based position among emitted chunks)
        and ``length`` (length of the stripped chunk). Whitespace-only
        chunks are omitted.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is
            negative.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if overlap < 0:
        raise ValueError("overlap must not be negative")

    chunks = []
    text_length = len(text)
    start = 0
    while start < text_length:
        # Clamp so the recorded end offset never points past the text.
        end = min(start + chunk_size, text_length)
        piece = text[start:end]

        # Try to break at a sentence boundary (never for the final chunk).
        if end < text_length:
            break_point = max(piece.rfind('.'), piece.rfind('\n'))
            if break_point > chunk_size * 0.7:
                piece = piece[:break_point + 1]
                end = start + break_point + 1

        stripped = piece.strip()
        if stripped:
            chunks.append({
                'text': stripped,
                'start': start,
                'end': end,
                'index': len(chunks),
                'length': len(stripped),
            })

        if end >= text_length:
            # The whole text is consumed; stepping back by `overlap` would
            # only re-emit tail fragments of the chunk just produced.
            break
        start = max(start + 1, end - overlap)
    return chunks
# ===== METADATA DATABASE =====
def init_metadata_db():
    """Initialize the SQLite database that stores per-chunk metadata.

    Opens ``metadata.db`` in the current working directory (with
    ``check_same_thread=False``) and stores the connection in the module
    global ``metadata_db``. If a connection is already stored there it is
    reused, so repeated initialization does not abandon open connections.
    The ``documents`` table and its filename index are created if missing.
    """
    global metadata_db
    # globals().get() tolerates the name not having been bound yet.
    if globals().get("metadata_db") is None:
        db_path = "metadata.db"
        metadata_db = sqlite3.connect(db_path, check_same_thread=False)

    metadata_db.execute("""
        CREATE TABLE IF NOT EXISTS documents (
            doc_id TEXT PRIMARY KEY,
            filename TEXT NOT NULL,
            file_hash TEXT NOT NULL,
            original_text TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            total_chunks INTEGER NOT NULL,
            chunk_start INTEGER NOT NULL,
            chunk_end INTEGER NOT NULL,
            chunk_size INTEGER NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    metadata_db.execute("""
        CREATE INDEX IF NOT EXISTS idx_filename ON documents(filename);
    """)
    metadata_db.commit()
def add_document_metadata(doc_id: str, filename: str, file_hash: str,
                          original_text: str, chunk_info: Dict[str, Any], total_chunks: int):
    """Insert or replace the metadata row for one chunk and commit.

    Args:
        doc_id: Primary key of the row.
        filename: Name of the file the chunk came from.
        file_hash: Hash of the source file.
        original_text: Text of the chunk.
        chunk_info: Dict providing the ``index``, ``start``, ``end`` and
            ``length`` values stored for the chunk.
        total_chunks: Number of chunks produced for the source file.
    """
    global metadata_db
    insert_sql = (
        "\nINSERT OR REPLACE INTO documents\n"
        "(doc_id, filename, file_hash, original_text, chunk_index, total_chunks,\n"
        "chunk_start, chunk_end, chunk_size)\n"
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)\n"
    )
    row_values = (
        doc_id,
        filename,
        file_hash,
        original_text,
        chunk_info['index'],
        total_chunks,
        chunk_info['start'],
        chunk_info['end'],
        chunk_info['length'],
    )
    metadata_db.execute(insert_sql, row_values)
    metadata_db.commit()
def get_document_metadata(doc_id: str) -> Dict[str, Any]:
    """Look up the metadata row for ``doc_id``.

    Args:
        doc_id: Primary key of the row to fetch.

    Returns:
        A dict mapping column names to values for the matching row, or an
        empty dict when no row has that id.
    """
    global metadata_db
    lookup = metadata_db.execute(
        "SELECT * FROM documents WHERE doc_id = ?", (doc_id,)
    )
    record = lookup.fetchone()
    if not record:
        return {}
    # cursor.description holds one tuple per column; item 0 is the name.
    return {column[0]: value for column, value in zip(lookup.description, record)}
# ===== PYLATE INITIALIZATION =====
@spaces.GPU
def initialize_pylate(model_name: str = "colbert-ir/colbertv2.0") -> str:
    """Initialize the metadata database, ColBERT model, PLAID index and retriever.

    The new components are built in local variables and published to the
    module globals only after every step has succeeded, so a failed
    initialization never leaves a half-configured system behind.

    NOTE(review): this function is wrapped by ``spaces.GPU``. On ZeroGPU the
    wrapped call may run in a separate worker process -- confirm that the
    globals assigned here are visible to later calls.

    Args:
        model_name: Model identifier or path handed to ``models.ColBERT``.

    Returns:
        A human-readable status message; errors are reported in the message
        rather than raised.
    """
    global model, index, retriever
    try:
        # Initialize metadata database
        init_metadata_db()

        # Load ColBERT model and move it to the GPU if one is available
        new_model = models.ColBERT(model_name_or_path=model_name)
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            new_model = new_model.to('cuda')

        # override=True is passed as in the original setup; presumably it
        # recreates the on-disk index on every initialization -- confirm
        # against the PyLate documentation.
        new_index = indexes.PLAID(
            index_folder="./pylate_index",
            index_name="documents",
            override=True,
            kmeans_niters=1,  # Reduce k-means iterations
            nbits=1  # Reduce quantization bits
        )

        # Initialize retriever
        new_retriever = retrieve.ColBERT(index=new_index)
    except Exception as e:
        return f"❌ Error initializing PyLate: {str(e)}"

    model, index, retriever = new_model, new_index, new_retriever
    device_label = 'GPU' if use_gpu else 'CPU'
    return f"✅ PyLate initialized successfully!\nModel: {model_name}\nDevice: {device_label}"
# ===== DOCUMENT PROCESSING =====
@spaces.GPU
def process_documents(files, chunk_size: int = 1000, overlap: int = 100) -> str:
    """Extract, chunk, encode and index the uploaded documents.

    Supported types are PDF, DOCX and TXT (matched on the file name suffix,
    case-insensitively). Files of other types and files whose text cannot be
    extracted are listed in the report instead of being dropped silently.
    Chunk metadata is written to the database only after the chunks have been
    added to the index, so a failed encoding run leaves no orphan rows.

    Note: the extractors signal failure by returning a string starting with
    ``"Error:"``, so a document whose own text begins with ``"Error:"`` is
    reported as failed.

    Args:
        files: Uploaded files; each item is either a path string or an
            object whose ``name`` attribute holds the path.
        chunk_size: Maximum chunk length in characters.
        overlap: Characters shared between consecutive chunks.

    Returns:
        A human-readable status report; errors are reported in the message
        rather than raised.
    """
    global model, index, metadata_db
    # Compare against None explicitly: the truth value of model/index
    # objects is not guaranteed to mean "initialized".
    if model is None or index is None:
        return "❌ Please initialize PyLate first!"
    if not files:
        return "❌ No files uploaded!"
    try:
        # UI sliders may deliver floats; slicing needs integers.
        chunk_size = int(chunk_size)
        overlap = int(overlap)

        all_documents = []
        all_doc_ids = []
        pending_metadata = []
        processed_files = []
        indexed_file_count = 0

        for file in files:
            # Get file info
            file_path = getattr(file, "name", file)
            filename = Path(file_path).name

            # Pick the extractor based on file type
            lowered = filename.lower()
            if lowered.endswith('.pdf'):
                extractor = extract_text_from_pdf
            elif lowered.endswith('.docx'):
                extractor = extract_text_from_docx
            elif lowered.endswith('.txt'):
                extractor = extract_text_from_txt
            else:
                processed_files.append(f"{filename}: Skipped (unsupported file type)")
                continue

            # Calculate file hash
            with open(file_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()

            text = extractor(file_path)
            if not text or text.startswith("Error:"):
                processed_files.append(f"{filename}: Failed to extract text")
                continue

            # Chunk the text and queue every chunk for encoding
            chunks = chunk_text(text, chunk_size, overlap)
            for chunk in chunks:
                doc_id = f"{filename}_chunk_{chunk['index']}"
                all_documents.append(chunk['text'])
                all_doc_ids.append(doc_id)
                pending_metadata.append((doc_id, filename, file_hash, chunk, len(chunks)))
            processed_files.append(f"{filename}: {len(chunks)} chunks")
            indexed_file_count += 1

        if not all_documents:
            return "❌ No text could be extracted from uploaded files!"

        # Encode documents with PyLate
        document_embeddings = model.encode(
            all_documents,
            batch_size=16,  # Smaller batch for ZeroGPU
            is_query=False,
            show_progress_bar=True
        )

        # Add to PLAID index
        index.add_documents(
            documents_ids=all_doc_ids,
            documents_embeddings=document_embeddings
        )

        # Store metadata only once the chunks are actually searchable
        for doc_id, filename, file_hash, chunk, total_chunks in pending_metadata:
            add_document_metadata(
                doc_id=doc_id,
                filename=filename,
                file_hash=file_hash,
                original_text=chunk['text'],
                chunk_info=chunk,
                total_chunks=total_chunks
            )

        report_lines = [
            f"✅ Successfully processed {indexed_file_count} of {len(files)} files:",
            f"📄 Total chunks: {len(all_documents)}",
            "🔍 Indexed documents:",
        ]
        report_lines.extend(f" • {file_info}" for file_info in processed_files)
        return "\n".join(report_lines) + "\n"
    except Exception as e:
        return f"❌ Error processing documents: {str(e)}"
# ===== SEARCH FUNCTION =====
@spaces.GPU
def search_documents(query: str, k: int = 5, show_chunks: bool = True) -> str:
    """Search the indexed chunks and format the hits as Markdown text.

    Args:
        query: Natural-language search query.
        k: Number of results to retrieve; coerced to ``int`` because UI
            sliders may deliver floats.
        show_chunks: When true, include a preview (up to 300 characters) of
            each matching chunk.

    Returns:
        A formatted result listing, or a status message when the system is
        not initialized, the query is empty, nothing matches, or an error
        occurs. Errors are reported in the message rather than raised.
    """
    global model, retriever, metadata_db
    # Compare against None explicitly: the truth value of model/retriever
    # objects is not guaranteed to mean "initialized".
    if model is None or retriever is None:
        return "❌ Please initialize PyLate and process documents first!"
    if not query or not query.strip():
        return "❌ Please enter a search query!"
    try:
        # Encode query
        query_embedding = model.encode([query], is_query=True)

        # Search; the retriever returns one result list per query
        results = retriever.retrieve(query_embedding, k=int(k))[0]
        if not results:
            return "🔍 No results found for your query."

        # Format results with metadata
        formatted_results = [f"🔍 **Search Results for:** '{query}'\n"]
        for i, result in enumerate(results):
            doc_id = result['id']
            score = result['score']

            # An unknown id yields {}, so the .get() defaults below apply
            metadata = get_document_metadata(doc_id)
            formatted_results.append(f"## Result {i+1} (Score: {score:.2f})")
            formatted_results.append(
                f"**File:** {metadata.get('filename', 'Unknown')}")
            formatted_results.append(
                f"**Chunk:** {metadata.get('chunk_index', 0) + 1}/{metadata.get('total_chunks', 1)}")
            if show_chunks:
                text = metadata.get('original_text', '')
                preview = text[:300] + "..." if len(text) > 300 else text
                formatted_results.append(f"**Text:** {preview}")
            formatted_results.append("---")
        return "\n".join(formatted_results)
    except Exception as e:
        return f"❌ Error searching: {str(e)}"
# ===== GRADIO INTERFACE =====
def create_interface():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` app with four tabs: Setup (calls
    ``initialize_pylate``), Document Upload (calls ``process_documents``),
    Search (calls ``search_documents``) and Info (static help text).

    The Markdown texts are assembled from explicit lines so that their
    rendering does not depend on source indentation, with blank lines
    separating headings, paragraphs and lists.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    header_markdown = "\n".join([
        "# 🔍 PyLate Document Search",
        "### Powered by ColBERT and ZeroGPU",
        "",
        "Upload documents, process them with PyLate, and perform semantic search!",
        "",
        "**Note:** Using PyMuPDF and Unstructured for robust PDF text extraction.",
    ])
    info_markdown = "\n".join([
        "### About This System",
        "",
        "**PyLate Document Search** is a semantic search system that uses:",
        "",
        "- **PyLate**: A flexible library for ColBERT models",
        "- **ColBERT**: Late interaction retrieval for high-quality search",
        "- **ZeroGPU**: Hugging Face's free GPU infrastructure",
        "",
        "#### Features:",
        "",
        "- 📄 Multi-format document support (PDF, DOCX, TXT)",
        "- ✂️ Intelligent text chunking with overlap",
        "- 🧠 Semantic search using ColBERT embeddings",
        "- 💾 Metadata tracking for result context",
        "- ⚡ GPU-accelerated processing",
        "",
        "#### PDF Processing:",
        "",
        "- Uses PyMuPDF (fitz) for reliable text extraction",
        "- Falls back to Unstructured for complex PDFs",
        "- No dependency on PyPDF2",
        "",
        "#### Usage Tips:",
        "",
        "1. Initialize the system first (required)",
        "2. Upload your documents and process them",
        "3. Use natural language queries for best results",
        "4. Adjust chunk size based on your document types",
        "",
        "Built with ❤️ using PyLate and Gradio",
    ])

    with gr.Blocks(title="PyLate Document Search", theme=gr.themes.Soft()) as demo:
        gr.Markdown(header_markdown)

        with gr.Tab("🚀 Setup"):
            gr.Markdown("### Initialize PyLate System")
            model_choice = gr.Dropdown(
                choices=[
                    "colbert-ir/colbertv2.0",
                    "sentence-transformers/all-MiniLM-L6-v2"
                ],
                value="colbert-ir/colbertv2.0",
                label="Select Model"
            )
            init_btn = gr.Button("Initialize PyLate", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", lines=3)
            init_btn.click(
                initialize_pylate,
                inputs=model_choice,
                outputs=init_status
            )

        with gr.Tab("📄 Document Upload"):
            gr.Markdown("### Upload and Process Documents")
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"],
                        label="Upload Documents (PDF, DOCX, TXT)"
                    )
                    with gr.Row():
                        chunk_size = gr.Slider(
                            minimum=500,
                            maximum=3000,
                            value=1000,
                            step=100,
                            label="Chunk Size (characters)"
                        )
                        overlap = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=100,
                            step=50,
                            label="Chunk Overlap (characters)"
                        )
                    process_btn = gr.Button(
                        "Process Documents", variant="primary")
                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=10,
                        max_lines=15
                    )
            process_btn.click(
                process_documents,
                inputs=[file_upload, chunk_size, overlap],
                outputs=process_status
            )

        with gr.Tab("🔍 Search"):
            gr.Markdown("### Search Your Documents")
            with gr.Row():
                with gr.Column():
                    search_query = gr.Textbox(
                        label="Search Query",
                        placeholder="Enter your search query...",
                        lines=2
                    )
                    with gr.Row():
                        num_results = gr.Slider(
                            minimum=1,
                            maximum=20,
                            value=5,
                            step=1,
                            label="Number of Results"
                        )
                        show_chunks = gr.Checkbox(
                            value=True,
                            label="Show Text Chunks"
                        )
                    search_btn = gr.Button("Search", variant="primary")
                with gr.Column():
                    search_results = gr.Textbox(
                        label="Search Results",
                        lines=15,
                        max_lines=20
                    )
            search_btn.click(
                search_documents,
                inputs=[search_query, num_results, show_chunks],
                outputs=search_results
            )

        with gr.Tab("ℹ️ Info"):
            gr.Markdown(info_markdown)

    return demo
# ===== MAIN =====
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every network
    # interface at port 7860 without creating a public share link.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo = create_interface()
    demo.launch(**launch_options)