fullstack commited on
Commit
8c651f8
Β·
1 Parent(s): a3706b1
Files changed (1) hide show
  1. app.py +422 -436
app.py CHANGED
@@ -8,16 +8,13 @@ import json
8
  import hashlib
9
  from pathlib import Path
10
  from typing import List, Dict, Any, Tuple
11
- import PyPDF2
12
  import docx
13
  import fitz # pymupdf
14
  from unstructured.partition.auto import partition
15
 
16
-
17
  os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
18
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
19
 
20
-
21
  # PyLate imports
22
  from pylate import models, indexes, retrieve
23
 
@@ -29,495 +26,484 @@ metadata_db = None
29
 
30
  # ===== DOCUMENT PROCESSING FUNCTIONS =====
31
 
32
-
33
  def extract_text_from_pdf(file_path: str) -> str:
34
- """Extract text from PDF file."""
35
- text = ""
36
- try:
37
- # Try PyMuPDF first (better for complex PDFs)
38
- doc = fitz.open(file_path)
39
- for page in doc:
40
- text += page.get_text() + "\n"
41
- doc.close()
42
- except:
43
- # Fallback to PyPDF2
44
- try:
45
- with open(file_path, 'rb') as file:
46
- pdf_reader = PyPDF2.PdfReader(file)
47
- for page in pdf_reader.pages:
48
- text += page.extract_text() + "\n"
49
- except:
50
- # Last resort: unstructured
51
- try:
52
- elements = partition(filename=file_path)
53
- text = "\n".join([str(element) for element in elements])
54
- except:
55
- text = "Error: Could not extract text from PDF"
56
-
57
- return text.strip()
58
-
59
 
60
  def extract_text_from_docx(file_path: str) -> str:
61
- """Extract text from DOCX file."""
62
- try:
63
- doc = docx.Document(file_path)
64
- text = ""
65
- for paragraph in doc.paragraphs:
66
- text += paragraph.text + "\n"
67
- return text.strip()
68
- except:
69
- return "Error: Could not extract text from DOCX"
70
-
71
 
72
  def extract_text_from_txt(file_path: str) -> str:
73
- """Extract text from TXT file."""
74
- try:
75
- with open(file_path, 'r', encoding='utf-8') as file:
76
- return file.read().strip()
77
- except:
78
- try:
79
- with open(file_path, 'r', encoding='latin1') as file:
80
- return file.read().strip()
81
- except:
82
- return "Error: Could not read text file"
83
-
84
 
85
  def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[Dict[str, Any]]:
86
- """Chunk text with overlap and return metadata."""
87
- chunks = []
88
- start = 0
89
- chunk_index = 0
90
-
91
- while start < len(text):
92
- end = start + chunk_size
93
- chunk_text = text[start:end]
94
-
95
- # Try to break at sentence boundary
96
- if end < len(text):
97
- last_period = chunk_text.rfind('.')
98
- last_newline = chunk_text.rfind('\n')
99
- break_point = max(last_period, last_newline)
100
-
101
- if break_point > chunk_size * 0.7:
102
- chunk_text = chunk_text[:break_point + 1]
103
- end = start + break_point + 1
104
-
105
- if chunk_text.strip():
106
- chunks.append({
107
- 'text': chunk_text.strip(),
108
- 'start': start,
109
- 'end': end,
110
- 'index': chunk_index,
111
- 'length': len(chunk_text.strip())
112
- })
113
- chunk_index += 1
114
-
115
- start = max(start + 1, end - overlap)
116
-
117
- return chunks
118
 
119
  # ===== METADATA DATABASE =====
120
 
121
-
122
  def init_metadata_db():
123
- """Initialize SQLite database for metadata."""
124
- global metadata_db
125
-
126
- db_path = "metadata.db"
127
- metadata_db = sqlite3.connect(db_path, check_same_thread=False)
128
-
129
- metadata_db.execute("""
130
- CREATE TABLE IF NOT EXISTS documents (
131
- doc_id TEXT PRIMARY KEY,
132
- filename TEXT NOT NULL,
133
- file_hash TEXT NOT NULL,
134
- original_text TEXT NOT NULL,
135
- chunk_index INTEGER NOT NULL,
136
- total_chunks INTEGER NOT NULL,
137
- chunk_start INTEGER NOT NULL,
138
- chunk_end INTEGER NOT NULL,
139
- chunk_size INTEGER NOT NULL,
140
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
141
- )
142
- """)
143
-
144
- metadata_db.execute("""
145
- CREATE INDEX IF NOT EXISTS idx_filename ON documents(filename);
146
- """)
147
-
148
- metadata_db.commit()
149
-
150
 
151
  def add_document_metadata(doc_id: str, filename: str, file_hash: str,
152
- original_text: str, chunk_info: Dict[str, Any], total_chunks: int):
153
- """Add document metadata to database."""
154
- global metadata_db
155
-
156
- metadata_db.execute("""
157
- INSERT OR REPLACE INTO documents
158
- (doc_id, filename, file_hash, original_text, chunk_index, total_chunks,
159
- chunk_start, chunk_end, chunk_size)
160
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
161
- """, (
162
- doc_id, filename, file_hash, original_text,
163
- chunk_info['index'], total_chunks,
164
- chunk_info['start'], chunk_info['end'], chunk_info['length']
165
- ))
166
- metadata_db.commit()
167
-
168
 
169
  def get_document_metadata(doc_id: str) -> Dict[str, Any]:
170
- """Get document metadata by ID."""
171
- global metadata_db
172
 
173
- cursor = metadata_db.execute(
174
- "SELECT * FROM documents WHERE doc_id = ?", (doc_id,)
175
- )
176
- row = cursor.fetchone()
177
 
178
- if row:
179
- columns = [desc[0] for desc in cursor.description]
180
- return dict(zip(columns, row))
181
- return {}
182
 
183
  # ===== PYLATE INITIALIZATION =====
184
 
185
-
186
  @spaces.GPU
187
- def initialize_pylate(model_name: str = "lightonai/GTE-ModernColBERT-v1") -> str:
188
- """Initialize PyLate components on GPU."""
189
- global model, index, retriever
190
 
191
- try:
192
- # Initialize metadata database
193
- init_metadata_db()
194
 
195
- # Load ColBERT model
196
- model = models.ColBERT(model_name_or_path=model_name)
197
 
198
- # Move to GPU if available
199
- if torch.cuda.is_available():
200
- model = model.to('cuda')
201
 
202
- # Initialize PLAID index with CPU fallback for k-means
203
- index = indexes.PLAID(
204
- index_folder="./pylate_index",
205
- index_name="documents",
206
- override=True,
207
- kmeans_niters=1, # Reduce k-means iterations
208
- nbits=1 # Reduce quantization bits
209
- )
210
 
211
- # Initialize retriever
212
- retriever = retrieve.ColBERT(index=index)
213
 
214
- return f"βœ… PyLate initialized successfully!\nModel: {model_name}\nDevice: {'GPU' if torch.cuda.is_available() else 'CPU'}"
215
 
216
- except Exception as e:
217
- return f"❌ Error initializing PyLate: {str(e)}"
218
 
219
  # ===== DOCUMENT PROCESSING =====
220
 
221
-
222
  @spaces.GPU
223
  def process_documents(files, chunk_size: int = 1000, overlap: int = 100) -> str:
224
- """Process uploaded documents and add to index."""
225
- global model, index, metadata_db
226
-
227
- if not model or not index:
228
- return "❌ Please initialize PyLate first!"
229
-
230
- if not files:
231
- return "❌ No files uploaded!"
232
-
233
- try:
234
- all_documents = []
235
- all_doc_ids = []
236
- processed_files = []
237
-
238
- for file in files:
239
- # Get file info
240
- filename = Path(file.name).name
241
- file_path = file.name
242
-
243
- # Calculate file hash
244
- with open(file_path, 'rb') as f:
245
- file_hash = hashlib.md5(f.read()).hexdigest()
246
-
247
- # Extract text based on file type
248
- if filename.lower().endswith('.pdf'):
249
- text = extract_text_from_pdf(file_path)
250
- elif filename.lower().endswith('.docx'):
251
- text = extract_text_from_docx(file_path)
252
- elif filename.lower().endswith('.txt'):
253
- text = extract_text_from_txt(file_path)
254
- else:
255
- continue
256
-
257
- if not text or text.startswith("Error:"):
258
- continue
259
-
260
- # Chunk the text
261
- chunks = chunk_text(text, chunk_size, overlap)
262
-
263
- # Process each chunk
264
- for chunk in chunks:
265
- doc_id = f"{filename}_chunk_{chunk['index']}"
266
- all_documents.append(chunk['text'])
267
- all_doc_ids.append(doc_id)
268
-
269
- # Store metadata
270
- add_document_metadata(
271
- doc_id=doc_id,
272
- filename=filename,
273
- file_hash=file_hash,
274
- original_text=chunk['text'],
275
- chunk_info=chunk,
276
- total_chunks=len(chunks)
277
- )
278
-
279
- processed_files.append(f"{filename}: {len(chunks)} chunks")
280
-
281
- if not all_documents:
282
- return "❌ No text could be extracted from uploaded files!"
283
-
284
- # Encode documents with PyLate
285
- document_embeddings = model.encode(
286
- all_documents,
287
- batch_size=16, # Smaller batch for ZeroGPU
288
- is_query=False,
289
- show_progress_bar=True
290
- )
291
-
292
- # Add to PLAID index
293
- index.add_documents(
294
- documents_ids=all_doc_ids,
295
- documents_embeddings=document_embeddings
296
- )
297
-
298
- result = f"βœ… Successfully processed {len(files)} files:\n"
299
- result += f"πŸ“„ Total chunks: {len(all_documents)}\n"
300
- result += f"πŸ” Indexed documents:\n"
301
- for file_info in processed_files:
302
- result += f" β€’ {file_info}\n"
303
-
304
- return result
305
-
306
- except Exception as e:
307
- return f"❌ Error processing documents: {str(e)}"
 
308
 
309
  # ===== SEARCH FUNCTION =====
310
 
311
-
312
  @spaces.GPU
313
  def search_documents(query: str, k: int = 5, show_chunks: bool = True) -> str:
314
- """Search documents using PyLate."""
315
- global model, retriever, metadata_db
316
 
317
- if not model or not retriever:
318
- return "❌ Please initialize PyLate and process documents first!"
319
 
320
- if not query.strip():
321
- return "❌ Please enter a search query!"
322
 
323
- try:
324
- # Encode query
325
- query_embedding = model.encode([query], is_query=True)
326
 
327
- # Search
328
- results = retriever.retrieve(query_embedding, k=k)[0]
329
 
330
- if not results:
331
- return "πŸ” No results found for your query."
332
 
333
- # Format results with metadata
334
- formatted_results = [f"πŸ” **Search Results for:** '{query}'\n"]
335
 
336
- for i, result in enumerate(results):
337
- doc_id = result['id']
338
- score = result['score']
339
 
340
- # Get metadata
341
- metadata = get_document_metadata(doc_id)
342
 
343
- formatted_results.append(f"## Result {i+1} (Score: {score:.2f})")
344
- formatted_results.append(
345
- f"**File:** {metadata.get('filename', 'Unknown')}")
346
- formatted_results.append(
347
- f"**Chunk:** {metadata.get('chunk_index', 0) + 1}/{metadata.get('total_chunks', 1)}")
348
 
349
- if show_chunks:
350
- text = metadata.get('original_text', '')
351
- preview = text[:300] + "..." if len(text) > 300 else text
352
- formatted_results.append(f"**Text:** {preview}")
353
 
354
- formatted_results.append("---")
355
 
356
- return "\n".join(formatted_results)
357
 
358
- except Exception as e:
359
- return f"❌ Error searching: {str(e)}"
360
 
361
  # ===== GRADIO INTERFACE =====
362
 
363
-
364
  def create_interface():
365
- """Create the Gradio interface."""
366
-
367
- with gr.Blocks(title="PyLate Document Search", theme=gr.themes.Soft()) as demo:
368
- gr.Markdown("""
369
- # πŸ” PyLate Document Search
370
- ### Powered by ColBERT and ZeroGPU H100
371
-
372
- Upload documents, process them with PyLate, and perform semantic search!
373
- """)
374
-
375
- with gr.Tab("πŸš€ Setup"):
376
- gr.Markdown("### Initialize PyLate System")
377
-
378
- model_choice = gr.Dropdown(
379
- choices=[
380
- # "lightonai/GTE-ModernColBERT-v1",
381
- "colbert-ir/colbertv2.0",
382
- "sentence-transformers/all-MiniLM-L6-v2"
383
- ],
384
- value="lightonai/GTE-ModernColBERT-v1",
385
- label="Select Model"
386
- )
387
-
388
- init_btn = gr.Button("Initialize PyLate", variant="primary")
389
- init_status = gr.Textbox(label="Initialization Status", lines=3)
390
-
391
- init_btn.click(
392
- initialize_pylate,
393
- inputs=model_choice,
394
- outputs=init_status
395
- )
396
-
397
- with gr.Tab("πŸ“„ Document Upload"):
398
- gr.Markdown("### Upload and Process Documents")
399
-
400
- with gr.Row():
401
- with gr.Column():
402
- file_upload = gr.File(
403
- file_count="multiple",
404
- file_types=[".pdf", ".docx", ".txt"],
405
- label="Upload Documents (PDF, DOCX, TXT)"
406
- )
407
-
408
- with gr.Row():
409
- chunk_size = gr.Slider(
410
- minimum=500,
411
- maximum=3000,
412
- value=1000,
413
- step=100,
414
- label="Chunk Size (characters)"
415
- )
416
-
417
- overlap = gr.Slider(
418
- minimum=0,
419
- maximum=500,
420
- value=100,
421
- step=50,
422
- label="Chunk Overlap (characters)"
423
- )
424
-
425
- process_btn = gr.Button(
426
- "Process Documents", variant="primary")
427
-
428
- with gr.Column():
429
- process_status = gr.Textbox(
430
- label="Processing Status",
431
- lines=10,
432
- max_lines=15
433
- )
434
-
435
- process_btn.click(
436
- process_documents,
437
- inputs=[file_upload, chunk_size, overlap],
438
- outputs=process_status
439
- )
440
-
441
- with gr.Tab("πŸ” Search"):
442
- gr.Markdown("### Search Your Documents")
443
-
444
- with gr.Row():
445
- with gr.Column():
446
- search_query = gr.Textbox(
447
- label="Search Query",
448
- placeholder="Enter your search query...",
449
- lines=2
450
- )
451
-
452
- with gr.Row():
453
- num_results = gr.Slider(
454
- minimum=1,
455
- maximum=20,
456
- value=5,
457
- step=1,
458
- label="Number of Results"
459
- )
460
-
461
- show_chunks = gr.Checkbox(
462
- value=True,
463
- label="Show Text Chunks"
464
- )
465
-
466
- search_btn = gr.Button("Search", variant="primary")
467
-
468
- with gr.Column():
469
- search_results = gr.Textbox(
470
- label="Search Results",
471
- lines=15,
472
- max_lines=20
473
- )
474
-
475
- search_btn.click(
476
- search_documents,
477
- inputs=[search_query, num_results, show_chunks],
478
- outputs=search_results
479
- )
480
-
481
- with gr.Tab("ℹ️ Info"):
482
- gr.Markdown("""
483
- ### About This System
484
-
485
- **PyLate Document Search** is a semantic search system that uses:
486
-
487
- - **PyLate**: A flexible library for ColBERT models
488
- - **ColBERT**: Late interaction retrieval for high-quality search
489
- - **ZeroGPU**: Hugging Face's free H100 GPU infrastructure
490
-
491
- #### Features:
492
- - πŸ“„ Multi-format document support (PDF, DOCX, TXT)
493
- - βœ‚οΈ Intelligent text chunking with overlap
494
- - 🧠 Semantic search using ColBERT embeddings
495
- - πŸ’Ύ Metadata tracking for result context
496
- - ⚑ GPU-accelerated processing
497
-
498
- #### Usage Tips:
499
- 1. Initialize the system first (required)
500
- 2. Upload your documents and process them
501
- 3. Use natural language queries for best results
502
- 4. Adjust chunk size based on your document types
503
-
504
- #### Model Information:
505
- - **GTE-ModernColBERT**: Latest high-performance model
506
- - **ColBERTv2**: Original Stanford implementation
507
- - **MiniLM**: Faster, smaller model for quick testing
508
-
509
- Built with ❀️ using PyLate and Gradio
510
- """)
511
-
512
- return demo
 
513
 
514
  # ===== MAIN =====
515
 
516
-
517
  if __name__ == "__main__":
518
- demo = create_interface()
519
- demo.launch(
520
- share=False,
521
- server_name="0.0.0.0",
522
- server_port=7860
523
- )
 
8
  import hashlib
9
  from pathlib import Path
10
  from typing import List, Dict, Any, Tuple
 
11
  import docx
12
  import fitz # pymupdf
13
  from unstructured.partition.auto import partition
14
 
 
15
  os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
16
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
17
 
 
18
  # PyLate imports
19
  from pylate import models, indexes, retrieve
20
 
 
26
 
27
  # ===== DOCUMENT PROCESSING FUNCTIONS =====
28
 
 
29
  def extract_text_from_pdf(file_path: str) -> str:
30
+ """Extract text from PDF file using PyMuPDF and unstructured as fallback."""
31
+ text = ""
32
+ try:
33
+ # Use PyMuPDF (fitz) - more reliable than PyPDF2
34
+ doc = fitz.open(file_path)
35
+ for page in doc:
36
+ text += page.get_text() + "\n"
37
+ doc.close()
38
+
39
+ # If no text extracted, try unstructured
40
+ if not text.strip():
41
+ elements = partition(filename=file_path)
42
+ text = "\n".join([str(element) for element in elements])
43
+
44
+ except Exception as e:
45
+ # Final fallback to unstructured
46
+ try:
47
+ elements = partition(filename=file_path)
48
+ text = "\n".join([str(element) for element in elements])
49
+ except:
50
+ text = f"Error: Could not extract text from PDF: {str(e)}"
51
+
52
+ return text.strip()
 
 
53
 
54
  def extract_text_from_docx(file_path: str) -> str:
55
+ """Extract text from DOCX file."""
56
+ try:
57
+ doc = docx.Document(file_path)
58
+ text = ""
59
+ for paragraph in doc.paragraphs:
60
+ text += paragraph.text + "\n"
61
+ return text.strip()
62
+ except Exception as e:
63
+ return f"Error: Could not extract text from DOCX: {str(e)}"
 
64
 
65
  def extract_text_from_txt(file_path: str) -> str:
66
+ """Extract text from TXT file."""
67
+ try:
68
+ with open(file_path, 'r', encoding='utf-8') as file:
69
+ return file.read().strip()
70
+ except:
71
+ try:
72
+ with open(file_path, 'r', encoding='latin1') as file:
73
+ return file.read().strip()
74
+ except Exception as e:
75
+ return f"Error: Could not read text file: {str(e)}"
 
76
 
77
  def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[Dict[str, Any]]:
78
+ """Chunk text with overlap and return metadata."""
79
+ chunks = []
80
+ start = 0
81
+ chunk_index = 0
82
+
83
+ while start < len(text):
84
+ end = start + chunk_size
85
+ chunk_text = text[start:end]
86
+
87
+ # Try to break at sentence boundary
88
+ if end < len(text):
89
+ last_period = chunk_text.rfind('.')
90
+ last_newline = chunk_text.rfind('\n')
91
+ break_point = max(last_period, last_newline)
92
+
93
+ if break_point > chunk_size * 0.7:
94
+ chunk_text = chunk_text[:break_point + 1]
95
+ end = start + break_point + 1
96
+
97
+ if chunk_text.strip():
98
+ chunks.append({
99
+ 'text': chunk_text.strip(),
100
+ 'start': start,
101
+ 'end': end,
102
+ 'index': chunk_index,
103
+ 'length': len(chunk_text.strip())
104
+ })
105
+ chunk_index += 1
106
+
107
+ start = max(start + 1, end - overlap)
108
+
109
+ return chunks
110
 
111
  # ===== METADATA DATABASE =====
112
 
 
113
  def init_metadata_db():
114
+ """Initialize SQLite database for metadata."""
115
+ global metadata_db
116
+
117
+ db_path = "metadata.db"
118
+ metadata_db = sqlite3.connect(db_path, check_same_thread=False)
119
+
120
+ metadata_db.execute("""
121
+ CREATE TABLE IF NOT EXISTS documents (
122
+ doc_id TEXT PRIMARY KEY,
123
+ filename TEXT NOT NULL,
124
+ file_hash TEXT NOT NULL,
125
+ original_text TEXT NOT NULL,
126
+ chunk_index INTEGER NOT NULL,
127
+ total_chunks INTEGER NOT NULL,
128
+ chunk_start INTEGER NOT NULL,
129
+ chunk_end INTEGER NOT NULL,
130
+ chunk_size INTEGER NOT NULL,
131
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
132
+ )
133
+ """)
134
+
135
+ metadata_db.execute("""
136
+ CREATE INDEX IF NOT EXISTS idx_filename ON documents(filename);
137
+ """)
138
+
139
+ metadata_db.commit()
 
140
 
141
  def add_document_metadata(doc_id: str, filename: str, file_hash: str,
142
+ original_text: str, chunk_info: Dict[str, Any], total_chunks: int):
143
+ """Add document metadata to database."""
144
+ global metadata_db
145
+
146
+ metadata_db.execute("""
147
+ INSERT OR REPLACE INTO documents
148
+ (doc_id, filename, file_hash, original_text, chunk_index, total_chunks,
149
+ chunk_start, chunk_end, chunk_size)
150
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
151
+ """, (
152
+ doc_id, filename, file_hash, original_text,
153
+ chunk_info['index'], total_chunks,
154
+ chunk_info['start'], chunk_info['end'], chunk_info['length']
155
+ ))
156
+ metadata_db.commit()
 
157
 
158
  def get_document_metadata(doc_id: str) -> Dict[str, Any]:
159
+ """Get document metadata by ID."""
160
+ global metadata_db
161
 
162
+ cursor = metadata_db.execute(
163
+ "SELECT * FROM documents WHERE doc_id = ?", (doc_id,)
164
+ )
165
+ row = cursor.fetchone()
166
 
167
+ if row:
168
+ columns = [desc[0] for desc in cursor.description]
169
+ return dict(zip(columns, row))
170
+ return {}
171
 
172
  # ===== PYLATE INITIALIZATION =====
173
 
 
174
  @spaces.GPU
175
+ def initialize_pylate(model_name: str = "colbert-ir/colbertv2.0") -> str:
176
+ """Initialize PyLate components on GPU."""
177
+ global model, index, retriever
178
 
179
+ try:
180
+ # Initialize metadata database
181
+ init_metadata_db()
182
 
183
+ # Load ColBERT model
184
+ model = models.ColBERT(model_name_or_path=model_name)
185
 
186
+ # Move to GPU if available
187
+ if torch.cuda.is_available():
188
+ model = model.to('cuda')
189
 
190
+ # Initialize PLAID index with CPU fallback for k-means
191
+ index = indexes.PLAID(
192
+ index_folder="./pylate_index",
193
+ index_name="documents",
194
+ override=True,
195
+ kmeans_niters=1, # Reduce k-means iterations
196
+ nbits=1 # Reduce quantization bits
197
+ )
198
 
199
+ # Initialize retriever
200
+ retriever = retrieve.ColBERT(index=index)
201
 
202
+ return f"βœ… PyLate initialized successfully!\nModel: {model_name}\nDevice: {'GPU' if torch.cuda.is_available() else 'CPU'}"
203
 
204
+ except Exception as e:
205
+ return f"❌ Error initializing PyLate: {str(e)}"
206
 
207
  # ===== DOCUMENT PROCESSING =====
208
 
 
209
  @spaces.GPU
210
  def process_documents(files, chunk_size: int = 1000, overlap: int = 100) -> str:
211
+ """Process uploaded documents and add to index."""
212
+ global model, index, metadata_db
213
+
214
+ if not model or not index:
215
+ return "❌ Please initialize PyLate first!"
216
+
217
+ if not files:
218
+ return "❌ No files uploaded!"
219
+
220
+ try:
221
+ all_documents = []
222
+ all_doc_ids = []
223
+ processed_files = []
224
+
225
+ for file in files:
226
+ # Get file info
227
+ filename = Path(file.name).name
228
+ file_path = file.name
229
+
230
+ # Calculate file hash
231
+ with open(file_path, 'rb') as f:
232
+ file_hash = hashlib.md5(f.read()).hexdigest()
233
+
234
+ # Extract text based on file type
235
+ if filename.lower().endswith('.pdf'):
236
+ text = extract_text_from_pdf(file_path)
237
+ elif filename.lower().endswith('.docx'):
238
+ text = extract_text_from_docx(file_path)
239
+ elif filename.lower().endswith('.txt'):
240
+ text = extract_text_from_txt(file_path)
241
+ else:
242
+ continue
243
+
244
+ if not text or text.startswith("Error:"):
245
+ processed_files.append(f"{filename}: Failed to extract text")
246
+ continue
247
+
248
+ # Chunk the text
249
+ chunks = chunk_text(text, chunk_size, overlap)
250
+
251
+ # Process each chunk
252
+ for chunk in chunks:
253
+ doc_id = f"{filename}_chunk_{chunk['index']}"
254
+ all_documents.append(chunk['text'])
255
+ all_doc_ids.append(doc_id)
256
+
257
+ # Store metadata
258
+ add_document_metadata(
259
+ doc_id=doc_id,
260
+ filename=filename,
261
+ file_hash=file_hash,
262
+ original_text=chunk['text'],
263
+ chunk_info=chunk,
264
+ total_chunks=len(chunks)
265
+ )
266
+
267
+ processed_files.append(f"{filename}: {len(chunks)} chunks")
268
+
269
+ if not all_documents:
270
+ return "❌ No text could be extracted from uploaded files!"
271
+
272
+ # Encode documents with PyLate
273
+ document_embeddings = model.encode(
274
+ all_documents,
275
+ batch_size=16, # Smaller batch for ZeroGPU
276
+ is_query=False,
277
+ show_progress_bar=True
278
+ )
279
+
280
+ # Add to PLAID index
281
+ index.add_documents(
282
+ documents_ids=all_doc_ids,
283
+ documents_embeddings=document_embeddings
284
+ )
285
+
286
+ result = f"βœ… Successfully processed {len(files)} files:\n"
287
+ result += f"πŸ“„ Total chunks: {len(all_documents)}\n"
288
+ result += f"πŸ” Indexed documents:\n"
289
+ for file_info in processed_files:
290
+ result += f" β€’ {file_info}\n"
291
+
292
+ return result
293
+
294
+ except Exception as e:
295
+ return f"❌ Error processing documents: {str(e)}"
296
 
297
  # ===== SEARCH FUNCTION =====
298
 
 
299
  @spaces.GPU
300
  def search_documents(query: str, k: int = 5, show_chunks: bool = True) -> str:
301
+ """Search documents using PyLate."""
302
+ global model, retriever, metadata_db
303
 
304
+ if not model or not retriever:
305
+ return "❌ Please initialize PyLate and process documents first!"
306
 
307
+ if not query.strip():
308
+ return "❌ Please enter a search query!"
309
 
310
+ try:
311
+ # Encode query
312
+ query_embedding = model.encode([query], is_query=True)
313
 
314
+ # Search
315
+ results = retriever.retrieve(query_embedding, k=k)[0]
316
 
317
+ if not results:
318
+ return "πŸ” No results found for your query."
319
 
320
+ # Format results with metadata
321
+ formatted_results = [f"πŸ” **Search Results for:** '{query}'\n"]
322
 
323
+ for i, result in enumerate(results):
324
+ doc_id = result['id']
325
+ score = result['score']
326
 
327
+ # Get metadata
328
+ metadata = get_document_metadata(doc_id)
329
 
330
+ formatted_results.append(f"## Result {i+1} (Score: {score:.2f})")
331
+ formatted_results.append(
332
+ f"**File:** {metadata.get('filename', 'Unknown')}")
333
+ formatted_results.append(
334
+ f"**Chunk:** {metadata.get('chunk_index', 0) + 1}/{metadata.get('total_chunks', 1)}")
335
 
336
+ if show_chunks:
337
+ text = metadata.get('original_text', '')
338
+ preview = text[:300] + "..." if len(text) > 300 else text
339
+ formatted_results.append(f"**Text:** {preview}")
340
 
341
+ formatted_results.append("---")
342
 
343
+ return "\n".join(formatted_results)
344
 
345
+ except Exception as e:
346
+ return f"❌ Error searching: {str(e)}"
347
 
348
  # ===== GRADIO INTERFACE =====
349
 
 
350
  def create_interface():
351
+ """Create the Gradio interface."""
352
+
353
+ with gr.Blocks(title="PyLate Document Search", theme=gr.themes.Soft()) as demo:
354
+ gr.Markdown("""
355
+ # πŸ” PyLate Document Search
356
+ ### Powered by ColBERT and ZeroGPU
357
+
358
+ Upload documents, process them with PyLate, and perform semantic search!
359
+
360
+ **Note:** Using PyMuPDF and Unstructured for robust PDF text extraction.
361
+ """)
362
+
363
+ with gr.Tab("πŸš€ Setup"):
364
+ gr.Markdown("### Initialize PyLate System")
365
+
366
+ model_choice = gr.Dropdown(
367
+ choices=[
368
+ "colbert-ir/colbertv2.0",
369
+ "sentence-transformers/all-MiniLM-L6-v2"
370
+ ],
371
+ value="colbert-ir/colbertv2.0",
372
+ label="Select Model"
373
+ )
374
+
375
+ init_btn = gr.Button("Initialize PyLate", variant="primary")
376
+ init_status = gr.Textbox(label="Initialization Status", lines=3)
377
+
378
+ init_btn.click(
379
+ initialize_pylate,
380
+ inputs=model_choice,
381
+ outputs=init_status
382
+ )
383
+
384
+ with gr.Tab("πŸ“„ Document Upload"):
385
+ gr.Markdown("### Upload and Process Documents")
386
+
387
+ with gr.Row():
388
+ with gr.Column():
389
+ file_upload = gr.File(
390
+ file_count="multiple",
391
+ file_types=[".pdf", ".docx", ".txt"],
392
+ label="Upload Documents (PDF, DOCX, TXT)"
393
+ )
394
+
395
+ with gr.Row():
396
+ chunk_size = gr.Slider(
397
+ minimum=500,
398
+ maximum=3000,
399
+ value=1000,
400
+ step=100,
401
+ label="Chunk Size (characters)"
402
+ )
403
+
404
+ overlap = gr.Slider(
405
+ minimum=0,
406
+ maximum=500,
407
+ value=100,
408
+ step=50,
409
+ label="Chunk Overlap (characters)"
410
+ )
411
+
412
+ process_btn = gr.Button(
413
+ "Process Documents", variant="primary")
414
+
415
+ with gr.Column():
416
+ process_status = gr.Textbox(
417
+ label="Processing Status",
418
+ lines=10,
419
+ max_lines=15
420
+ )
421
+
422
+ process_btn.click(
423
+ process_documents,
424
+ inputs=[file_upload, chunk_size, overlap],
425
+ outputs=process_status
426
+ )
427
+
428
+ with gr.Tab("πŸ” Search"):
429
+ gr.Markdown("### Search Your Documents")
430
+
431
+ with gr.Row():
432
+ with gr.Column():
433
+ search_query = gr.Textbox(
434
+ label="Search Query",
435
+ placeholder="Enter your search query...",
436
+ lines=2
437
+ )
438
+
439
+ with gr.Row():
440
+ num_results = gr.Slider(
441
+ minimum=1,
442
+ maximum=20,
443
+ value=5,
444
+ step=1,
445
+ label="Number of Results"
446
+ )
447
+
448
+ show_chunks = gr.Checkbox(
449
+ value=True,
450
+ label="Show Text Chunks"
451
+ )
452
+
453
+ search_btn = gr.Button("Search", variant="primary")
454
+
455
+ with gr.Column():
456
+ search_results = gr.Textbox(
457
+ label="Search Results",
458
+ lines=15,
459
+ max_lines=20
460
+ )
461
+
462
+ search_btn.click(
463
+ search_documents,
464
+ inputs=[search_query, num_results, show_chunks],
465
+ outputs=search_results
466
+ )
467
+
468
+ with gr.Tab("ℹ️ Info"):
469
+ gr.Markdown("""
470
+ ### About This System
471
+
472
+ **PyLate Document Search** is a semantic search system that uses:
473
+
474
+ - **PyLate**: A flexible library for ColBERT models
475
+ - **ColBERT**: Late interaction retrieval for high-quality search
476
+ - **ZeroGPU**: Hugging Face's free GPU infrastructure
477
+
478
+ #### Features:
479
+ - πŸ“„ Multi-format document support (PDF, DOCX, TXT)
480
+ - βœ‚οΈ Intelligent text chunking with overlap
481
+ - 🧠 Semantic search using ColBERT embeddings
482
+ - πŸ’Ύ Metadata tracking for result context
483
+ - ⚑ GPU-accelerated processing
484
+
485
+ #### PDF Processing:
486
+ - Uses PyMuPDF (fitz) for reliable text extraction
487
+ - Falls back to Unstructured for complex PDFs
488
+ - No dependency on PyPDF2
489
+
490
+ #### Usage Tips:
491
+ 1. Initialize the system first (required)
492
+ 2. Upload your documents and process them
493
+ 3. Use natural language queries for best results
494
+ 4. Adjust chunk size based on your document types
495
+
496
+ Built with ❀️ using PyLate and Gradio
497
+ """)
498
+
499
+ return demo
500
 
501
  # ===== MAIN =====
502
 
 
503
  if __name__ == "__main__":
504
+ demo = create_interface()
505
+ demo.launch(
506
+ share=False,
507
+ server_name="0.0.0.0",
508
+ server_port=7860
509
+ )