import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import mteb
from sqlitedict import SqliteDict

from pylate import indexes, models, retrieve

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
class IndexType(Enum):
    """Supported index types."""

    PREBUILT = "prebuilt"
    LOCAL = "local"
@dataclass
class IndexConfig:
    """Configuration for a search index."""

    name: str
    type: IndexType
    path: str
    description: Optional[str] = None
class MCPyLate:
    """Main server class that manages PyLate indexes and search operations."""

    def __init__(self, override: bool = False):
        self.logger = logging.getLogger(__name__)
        dataset_name = "leetcode"
        model_name = "lightonai/Reason-ModernColBERT"
        # Rebuild the index if explicitly requested or if it does not exist yet.
        override = override or not os.path.exists(
            f"indexes/{dataset_name}_{model_name.split('/')[-1]}"
        )
        self.model = models.ColBERT(
            model_name_or_path=model_name,
        )
        self.index = indexes.PLAID(
            override=override,
            index_name=f"{dataset_name}_{model_name.split('/')[-1]}",
        )
        # Maps document IDs to raw text so results can be returned as full documents.
        self.id_to_doc = SqliteDict(
            f"./indexes/{dataset_name}_{model_name.split('/')[-1]}/id_to_doc.sqlite",
            outer_stack=False,
        )
        self.retriever = retrieve.ColBERT(index=self.index)
        if override:
            # Load the BRIGHT leetcode corpus, store the raw documents, and index their embeddings.
            tasks = mteb.get_tasks(tasks=["BrightRetrieval"])
            tasks[0].load_data()
            corpus = tasks[0].corpus[dataset_name]["standard"]
            for doc_id, doc in corpus.items():
                self.id_to_doc[doc_id] = doc
            self.id_to_doc.commit()  # Don't forget to commit to save changes!
            documents_embeddings = self.model.encode(
                sentences=list(corpus.values()),
                batch_size=100,
                is_query=False,
                show_progress_bar=True,
            )
            self.index.add_documents(
                documents_ids=list(corpus.keys()),
                documents_embeddings=documents_embeddings,
            )
        self.logger.info("Created PyLate MCP Server")
    def get_document(
        self,
        docid: str,
    ) -> Optional[Dict[str, Any]]:
        """Retrieve the full document for a document ID, or None if it is unknown."""
        if docid not in self.id_to_doc:
            return None
        return {"docid": docid, "text": self.id_to_doc[docid]}
    def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
        """Perform multi-vector ColBERT search over the indexed corpus."""
        try:
            query_embeddings = self.model.encode(
                sentences=[query],
                is_query=True,
                show_progress_bar=True,
                batch_size=32,
            )
            scores = self.retriever.retrieve(queries_embeddings=query_embeddings, k=k)
            results = []
            for score in scores[0]:
                results.append(
                    {
                        "docid": score["id"],
                        "score": round(score["score"], 5),
                        "text": self.id_to_doc[score["id"]],
                        # Alternatively, truncate long documents for compact results:
                        # "text": self.id_to_doc[score["id"]][:200] + "…"
                        # if len(self.id_to_doc[score["id"]]) > 200
                        # else self.id_to_doc[score["id"]],
                    }
                )
            return results
        except Exception as e:
            self.logger.error(f"Search failed: {e}")
            raise RuntimeError(f"Search operation failed: {e}") from e
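

# --- Usage sketch (illustrative) ---
# A minimal example of driving MCPyLate directly, assuming the ColBERT model and
# the BRIGHT leetcode corpus can be downloaded and the local "indexes/" folder is
# writable; the query string below is made up for demonstration.
if __name__ == "__main__":
    server = MCPyLate()
    for hit in server.search("two sum problem using a hash map", k=5):
        print(hit["docid"], hit["score"])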