import asyncio
import os
import string
from typing import List, Optional, Set

from rank_bm25 import BM25Okapi

# NLTK imports are commented out along with the resource-download logic below.
# import nltk
# from nltk.corpus import stopwords
# from nltk.stem import WordNetLemmatizer

# Commented out: this function downloaded the NLTK resources required for
# tokenization, stopword removal, and lemmatization.
# def download_nltk_resources():
#     """
#     Downloads required NLTK resources synchronously.
#     """
#     resources = ['punkt', 'stopwords', 'wordnet', 'omw-1.4']
#     nltk_data_path = "/tmp/nltk_data"
#     os.makedirs(nltk_data_path, exist_ok=True)
#     nltk.data.path.append(nltk_data_path)
#     for resource in resources:
#         try:
#             nltk.download(resource, download_dir=nltk_data_path, quiet=True)
#         except Exception as e:
#             print(f"Error downloading {resource}: {str(e)}")


class BM25_search:
    # Class-level flag used by the commented-out NLTK setup below.
    nltk_resources_downloaded = False

    def __init__(self, remove_stopwords: bool = True, perform_lemmatization: bool = False):
        """
        Initializes the BM25_search instance.

        The remove_stopwords and perform_lemmatization flags are retained,
        but have no effect while the NLTK-based preprocessing is disabled.
        """
        # Commented out NLTK resource initialization
        # if not BM25_search.nltk_resources_downloaded:
        #     download_nltk_resources()
        #     BM25_search.nltk_resources_downloaded = True
        self.documents: List[str] = []
        self.doc_ids: List[str] = []
        self.tokenized_docs: List[List[str]] = []
        self.bm25: Optional[BM25Okapi] = None
        self.remove_stopwords = remove_stopwords
        self.perform_lemmatization = perform_lemmatization
        # Commented out NLTK-specific tools
        # self.stop_words: Set[str] = set(stopwords.words('english')) if remove_stopwords else set()
        # self.lemmatizer = WordNetLemmatizer() if perform_lemmatization else None

    def preprocess(self, text: str) -> List[str]:
        """
        Preprocesses the input text by lowercasing and removing punctuation.
        NLTK-based tokenization, stopword removal, and lemmatization are
        commented out; a simple whitespace split is used instead.
        """
        text = text.lower().translate(str.maketrans('', '', string.punctuation))
        # tokens = nltk.word_tokenize(text)  # Commented out NLTK tokenization
        tokens = text.split()  # Basic whitespace tokenization as a fallback
        # if self.remove_stopwords:
        #     tokens = [token for token in tokens if token not in self.stop_words]
        # if self.perform_lemmatization and self.lemmatizer:
        #     tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
        return tokens
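
    # Example of the whitespace fallback above (illustrative input, not from
    # the original source):
    #     preprocess("Hello, World!") -> ["hello", "world"]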

    def add_document(self, doc_id: str, new_doc: str) -> None:
        """
        Adds a new document to the corpus and updates the BM25 index.

        Note: update_bm25 rebuilds the whole index, so each add costs
        O(corpus size); this is fine for small corpora.
        """
        processed_tokens = self.preprocess(new_doc)
        self.documents.append(new_doc)
        self.doc_ids.append(doc_id)
        self.tokenized_docs.append(processed_tokens)
        self.update_bm25()
        print(f"Added document ID: {doc_id}")

    async def remove_document(self, index: int) -> None:
        """
        Removes a document from the corpus based on its index and updates the
        BM25 index.

        Note: this coroutine currently awaits nothing, but callers must still
        await it.
        """
        if 0 <= index < len(self.documents):
            removed_doc_id = self.doc_ids[index]
            del self.documents[index]
            del self.doc_ids[index]
            del self.tokenized_docs[index]
            self.update_bm25()
            print(f"Removed document ID: {removed_doc_id}")
        else:
            print(f"Index {index} is out of bounds.")

    def update_bm25(self) -> None:
        """
        Rebuilds the BM25 index from the current tokenized documents.
        """
        if self.tokenized_docs:
            self.bm25 = BM25Okapi(self.tokenized_docs)
            print("BM25 index has been initialized.")
        else:
            print("No documents to initialize BM25.")

    def get_scores(self, query: str) -> List[float]:
        """
        Computes BM25 scores for all documents based on the given query.
        Scores are returned in the same order as the stored documents.
        """
        processed_query = self.preprocess(query)
        print(f"Tokenized Query: {processed_query}")
        if self.bm25:
            # BM25Okapi.get_scores returns a numpy array; convert it to match
            # the declared List[float] return type.
            return list(self.bm25.get_scores(processed_query))
        else:
            print("BM25 is not initialized.")
            return []
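
    # Illustrative pairing of scores with document IDs (variable names are
    # examples only):
    #     scores = bm25_search.get_scores("some query")
    #     ranked = sorted(zip(bm25_search.doc_ids, scores), key=lambda p: p[1], reverse=True)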

    def get_top_n_docs(self, query: str, n: int = 5) -> List[str]:
        """
        Returns the top N documents for a given query.
        """
        processed_query = self.preprocess(query)
        if self.bm25:
            return self.bm25.get_top_n(processed_query, self.documents, n)
        else:
            print("BM25 is not initialized.")
            return []

    def clear_documents(self) -> None:
        """
        Clears all documents and resets the BM25 index.
        """
        self.documents = []
        self.doc_ids = []
        self.tokenized_docs = []
        self.bm25 = None
        print("BM25 documents cleared and index reset.")

    def get_document(self, doc_id: str) -> str:
        """
        Retrieves a document by its document ID, or returns an empty string
        if the ID is not found.
        """
        try:
            index = self.doc_ids.index(doc_id)
            return self.documents[index]
        except ValueError:
            print(f"Document ID {doc_id} not found.")
            return ""


async def initialize_bm25_search(remove_stopwords: bool = True, perform_lemmatization: bool = False) -> BM25_search:
    """
    Initializes and returns a BM25_search instance.

    Kept async even though the NLTK resource download was removed, so existing
    callers that await this function continue to work.
    """
    return BM25_search(remove_stopwords, perform_lemmatization)
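

# A minimal usage sketch (illustrative; the document IDs and texts below are
# made-up examples, not from the original source): build a small index, score
# a query, fetch the top match, then remove a document via the async API.
if __name__ == "__main__":
    searcher = BM25_search()
    searcher.add_document("doc1", "The quick brown fox jumps over the lazy dog.")
    searcher.add_document("doc2", "BM25 is a ranking function used by search engines.")

    print(searcher.get_scores("ranking function for search"))
    print(searcher.get_top_n_docs("ranking function for search", n=1))

    # remove_document is a coroutine, so it must be driven by an event loop.
    asyncio.run(searcher.remove_document(0))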