import gradio as gr
import spaces
import torch
import os
import tempfile
import sqlite3
import json
import hashlib
from pathlib import Path
from typing import List, Dict, Any, Tuple
import PyPDF2
import docx
import fitz  # PyMuPDF
from unstructured.partition.auto import partition

os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["TORCH_COMPILE_DISABLE"] = "1"
# PyLate imports
from pylate import models, indexes, retrieve

# Global variables for PyLate components
model = None
index = None
retriever = None
metadata_db = None
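# Overall flow of this app:
#   1. initialize_pylate() loads a ColBERT model and creates a PLAID index.
#   2. process_documents() extracts text, chunks it, stores chunk metadata in
#      SQLite, encodes the chunks with PyLate, and adds them to the index.
#   3. search_documents() encodes the query, retrieves the top-k chunks, and
#      joins the results back to their metadata for display.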
# ===== DOCUMENT PROCESSING FUNCTIONS =====
def extract_text_from_pdf(file_path: str) -> str:
    """Extract text from a PDF file."""
    text = ""
    try:
        # Try PyMuPDF first (better for complex PDFs)
        doc = fitz.open(file_path)
        for page in doc:
            text += page.get_text() + "\n"
        doc.close()
    except Exception:
        # Fallback to PyPDF2
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    text += page.extract_text() + "\n"
        except Exception:
            # Last resort: unstructured
            try:
                elements = partition(filename=file_path)
                text = "\n".join([str(element) for element in elements])
            except Exception:
                text = "Error: Could not extract text from PDF"
    return text.strip()
def extract_text_from_docx(file_path: str) -> str:
    """Extract text from a DOCX file."""
    try:
        doc = docx.Document(file_path)
        text = ""
        for paragraph in doc.paragraphs:
            text += paragraph.text + "\n"
        return text.strip()
    except Exception:
        return "Error: Could not extract text from DOCX"
def extract_text_from_txt(file_path: str) -> str:
    """Extract text from a TXT file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read().strip()
    except Exception:
        try:
            with open(file_path, 'r', encoding='latin1') as file:
                return file.read().strip()
        except Exception:
            return "Error: Could not read text file"
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[Dict[str, Any]]:
    """Chunk text with overlap and return chunk metadata."""
    chunks = []
    start = 0
    chunk_index = 0
    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end]  # renamed from chunk_text to avoid shadowing the function
        # Try to break at a sentence boundary
        if end < len(text):
            last_period = chunk.rfind('.')
            last_newline = chunk.rfind('\n')
            break_point = max(last_period, last_newline)
            if break_point > chunk_size * 0.7:
                chunk = chunk[:break_point + 1]
                end = start + break_point + 1
        if chunk.strip():
            chunks.append({
                'text': chunk.strip(),
                'start': start,
                'end': end,
                'index': chunk_index,
                'length': len(chunk.strip())
            })
            chunk_index += 1
        start = max(start + 1, end - overlap)
    return chunks
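# Each chunk returned above is a dict of the form
#   {'text': ..., 'start': ..., 'end': ..., 'index': ..., 'length': ...}
# so every chunk can be mapped back to its character span in the source text.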
# ===== METADATA DATABASE =====
def init_metadata_db():
    """Initialize SQLite database for metadata."""
    global metadata_db
    db_path = "metadata.db"
    metadata_db = sqlite3.connect(db_path, check_same_thread=False)
    metadata_db.execute("""
        CREATE TABLE IF NOT EXISTS documents (
            doc_id TEXT PRIMARY KEY,
            filename TEXT NOT NULL,
            file_hash TEXT NOT NULL,
            original_text TEXT NOT NULL,
            chunk_index INTEGER NOT NULL,
            total_chunks INTEGER NOT NULL,
            chunk_start INTEGER NOT NULL,
            chunk_end INTEGER NOT NULL,
            chunk_size INTEGER NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    metadata_db.execute("""
        CREATE INDEX IF NOT EXISTS idx_filename ON documents(filename);
    """)
    metadata_db.commit()
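# Note: the table stores one row per chunk; doc_id matches the PLAID document
# id ("<filename>_chunk_<index>", assigned in process_documents), which is how
# search results are joined back to their metadata.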
def add_document_metadata(doc_id: str, filename: str, file_hash: str,
                          original_text: str, chunk_info: Dict[str, Any], total_chunks: int):
    """Add document metadata to the database."""
    global metadata_db
    metadata_db.execute("""
        INSERT OR REPLACE INTO documents
        (doc_id, filename, file_hash, original_text, chunk_index, total_chunks,
         chunk_start, chunk_end, chunk_size)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, (
        doc_id, filename, file_hash, original_text,
        chunk_info['index'], total_chunks,
        chunk_info['start'], chunk_info['end'], chunk_info['length']
    ))
    metadata_db.commit()
def get_document_metadata(doc_id: str) -> Dict[str, Any]:
    """Get document metadata by ID."""
    global metadata_db
    cursor = metadata_db.execute(
        "SELECT * FROM documents WHERE doc_id = ?", (doc_id,)
    )
    row = cursor.fetchone()
    if row:
        columns = [desc[0] for desc in cursor.description]
        return dict(zip(columns, row))
    return {}
# ===== PYLATE INITIALIZATION =====
def initialize_pylate(model_name: str = "lightonai/GTE-ModernColBERT-v1") -> str:
    """Initialize PyLate components, using the GPU when available."""
    global model, index, retriever
    try:
        # Initialize metadata database
        init_metadata_db()
        # Load ColBERT model
        model = models.ColBERT(model_name_or_path=model_name)
        # Move to GPU if available
        if torch.cuda.is_available():
            model = model.to('cuda')
        # Initialize PLAID index with lightweight settings
        index = indexes.PLAID(
            index_folder="./pylate_index",
            index_name="documents",
            override=True,
            kmeans_niters=1,  # fewer k-means iterations
            nbits=1           # coarser quantization
        )
        # Initialize retriever
        retriever = retrieve.ColBERT(index=index)
        return f"✅ PyLate initialized successfully!\nModel: {model_name}\nDevice: {'GPU' if torch.cuda.is_available() else 'CPU'}"
    except Exception as e:
        return f"❌ Error initializing PyLate: {str(e)}"
# ===== DOCUMENT PROCESSING =====
@spaces.GPU  # run encoding on a ZeroGPU slice (the spaces import suggests this was intended)
def process_documents(files, chunk_size: int = 1000, overlap: int = 100) -> str:
    """Process uploaded documents and add them to the index."""
    global model, index, metadata_db
    if not model or not index:
        return "❌ Please initialize PyLate first!"
    if not files:
        return "❌ No files uploaded!"
    try:
        all_documents = []
        all_doc_ids = []
        processed_files = []
        for file in files:
            # Get file info
            filename = Path(file.name).name
            file_path = file.name
            # Calculate file hash
            with open(file_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            # Extract text based on file type
            if filename.lower().endswith('.pdf'):
                text = extract_text_from_pdf(file_path)
            elif filename.lower().endswith('.docx'):
                text = extract_text_from_docx(file_path)
            elif filename.lower().endswith('.txt'):
                text = extract_text_from_txt(file_path)
            else:
                continue
            if not text or text.startswith("Error:"):
                continue
            # Chunk the text
            chunks = chunk_text(text, chunk_size, overlap)
            # Process each chunk
            for chunk in chunks:
                doc_id = f"{filename}_chunk_{chunk['index']}"
                all_documents.append(chunk['text'])
                all_doc_ids.append(doc_id)
                # Store metadata
                add_document_metadata(
                    doc_id=doc_id,
                    filename=filename,
                    file_hash=file_hash,
                    original_text=chunk['text'],
                    chunk_info=chunk,
                    total_chunks=len(chunks)
                )
            processed_files.append(f"{filename}: {len(chunks)} chunks")
        if not all_documents:
            return "❌ No text could be extracted from uploaded files!"
        # Encode documents with PyLate
        document_embeddings = model.encode(
            all_documents,
            batch_size=16,  # smaller batch for ZeroGPU
            is_query=False,
            show_progress_bar=True
        )
        # Add to PLAID index
        index.add_documents(
            documents_ids=all_doc_ids,
            documents_embeddings=document_embeddings
        )
        result = f"✅ Successfully processed {len(files)} files:\n"
        result += f"📄 Total chunks: {len(all_documents)}\n"
        result += f"📑 Indexed documents:\n"
        for file_info in processed_files:
            result += f"  • {file_info}\n"
        return result
    except Exception as e:
        return f"❌ Error processing documents: {str(e)}"
# ===== SEARCH FUNCTION =====
@spaces.GPU  # run query encoding on a ZeroGPU slice (the spaces import suggests this was intended)
def search_documents(query: str, k: int = 5, show_chunks: bool = True) -> str:
    """Search documents using PyLate."""
    global model, retriever, metadata_db
    if not model or not retriever:
        return "❌ Please initialize PyLate and process documents first!"
    if not query.strip():
        return "❌ Please enter a search query!"
    try:
        # Encode query
        query_embedding = model.encode([query], is_query=True)
        # Search
        results = retriever.retrieve(query_embedding, k=k)[0]
        if not results:
            return "🔍 No results found for your query."
        # Format results with metadata
        formatted_results = [f"🔍 **Search Results for:** '{query}'\n"]
        for i, result in enumerate(results):
            doc_id = result['id']
            score = result['score']
            # Get metadata
            metadata = get_document_metadata(doc_id)
            formatted_results.append(f"## Result {i+1} (Score: {score:.2f})")
            formatted_results.append(
                f"**File:** {metadata.get('filename', 'Unknown')}")
            formatted_results.append(
                f"**Chunk:** {metadata.get('chunk_index', 0) + 1}/{metadata.get('total_chunks', 1)}")
            if show_chunks:
                text = metadata.get('original_text', '')
                preview = text[:300] + "..." if len(text) > 300 else text
                formatted_results.append(f"**Text:** {preview}")
            formatted_results.append("---")
        return "\n".join(formatted_results)
    except Exception as e:
        return f"❌ Error searching: {str(e)}"
# ===== GRADIO INTERFACE =====
def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="PyLate Document Search", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🔍 PyLate Document Search
        ### Powered by ColBERT and ZeroGPU H100

        Upload documents, process them with PyLate, and perform semantic search!
        """)

        with gr.Tab("🚀 Setup"):
            gr.Markdown("### Initialize PyLate System")
            model_choice = gr.Dropdown(
                choices=[
                    "lightonai/GTE-ModernColBERT-v1",
                    "colbert-ir/colbertv2.0",
                    "sentence-transformers/all-MiniLM-L6-v2"
                ],
                value="lightonai/GTE-ModernColBERT-v1",
                label="Select Model"
            )
            init_btn = gr.Button("Initialize PyLate", variant="primary")
            init_status = gr.Textbox(label="Initialization Status", lines=3)
            init_btn.click(
                initialize_pylate,
                inputs=model_choice,
                outputs=init_status
            )
        with gr.Tab("📄 Document Upload"):
            gr.Markdown("### Upload and Process Documents")
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt"],
                        label="Upload Documents (PDF, DOCX, TXT)"
                    )
                    with gr.Row():
                        chunk_size = gr.Slider(
                            minimum=500,
                            maximum=3000,
                            value=1000,
                            step=100,
                            label="Chunk Size (characters)"
                        )
                        overlap = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=100,
                            step=50,
                            label="Chunk Overlap (characters)"
                        )
                    process_btn = gr.Button(
                        "Process Documents", variant="primary")
                with gr.Column():
                    process_status = gr.Textbox(
                        label="Processing Status",
                        lines=10,
                        max_lines=15
                    )
            process_btn.click(
                process_documents,
                inputs=[file_upload, chunk_size, overlap],
                outputs=process_status
            )
        with gr.Tab("🔎 Search"):
            gr.Markdown("### Search Your Documents")
            with gr.Row():
                with gr.Column():
                    search_query = gr.Textbox(
                        label="Search Query",
                        placeholder="Enter your search query...",
                        lines=2
                    )
                    with gr.Row():
                        num_results = gr.Slider(
                            minimum=1,
                            maximum=20,
                            value=5,
                            step=1,
                            label="Number of Results"
                        )
                        show_chunks = gr.Checkbox(
                            value=True,
                            label="Show Text Chunks"
                        )
                    search_btn = gr.Button("Search", variant="primary")
                with gr.Column():
                    search_results = gr.Textbox(
                        label="Search Results",
                        lines=15,
                        max_lines=20
                    )
            search_btn.click(
                search_documents,
                inputs=[search_query, num_results, show_chunks],
                outputs=search_results
            )
        with gr.Tab("ℹ️ Info"):
            gr.Markdown("""
            ### About This System
            **PyLate Document Search** is a semantic search system that uses:
            - **PyLate**: A flexible library for ColBERT models
            - **ColBERT**: Late interaction retrieval for high-quality search
            - **ZeroGPU**: Hugging Face's free H100 GPU infrastructure

            #### Features:
            - 📄 Multi-format document support (PDF, DOCX, TXT)
            - ✂️ Intelligent text chunking with overlap
            - 🧠 Semantic search using ColBERT embeddings
            - 💾 Metadata tracking for result context
            - ⚡ GPU-accelerated processing

            #### Usage Tips:
            1. Initialize the system first (required)
            2. Upload your documents and process them
            3. Use natural language queries for best results
            4. Adjust chunk size based on your document types

            #### Model Information:
            - **GTE-ModernColBERT**: Latest high-performance model
            - **ColBERTv2**: Original Stanford implementation
            - **MiniLM**: Faster, smaller model for quick testing

            Built with ❤️ using PyLate and Gradio
            """)

    return demo
# ===== MAIN =====
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )