# bm25_search.py
import asyncio
import string
from typing import List, Optional, Set

import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from rank_bm25 import BM25Okapi


def download_nltk_resources():
    """
    Downloads the required NLTK resources synchronously.
    """
    resources = ['punkt', 'stopwords', 'wordnet', 'omw-1.4']
    for resource in resources:
        try:
            nltk.download(resource, quiet=True)
        except Exception as e:
            print(f"Error downloading {resource}: {str(e)}")


class BM25_search:
    # Class variable to track whether NLTK resources have been downloaded
    nltk_resources_downloaded = False

    def __init__(self, remove_stopwords: bool = True, perform_lemmatization: bool = False):
        """
        Initializes the BM25_search.

        Parameters:
        - remove_stopwords (bool): Whether to remove stopwords during preprocessing.
        - perform_lemmatization (bool): Whether to perform lemmatization on tokens.
        """
        # Download NLTK resources only once per process
        if not BM25_search.nltk_resources_downloaded:
            download_nltk_resources()
            BM25_search.nltk_resources_downloaded = True
        self.documents: List[str] = []
        self.doc_ids: List[str] = []
        self.tokenized_docs: List[List[str]] = []
        self.bm25: Optional[BM25Okapi] = None
        self.remove_stopwords = remove_stopwords
        self.perform_lemmatization = perform_lemmatization
        self.stop_words: Set[str] = set(stopwords.words('english')) if remove_stopwords else set()
        self.lemmatizer = WordNetLemmatizer() if perform_lemmatization else None
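
    # Construction examples (illustrative, not exhaustive):
    #   BM25_search()                            -> stopwords removed, no lemmatization
    #   BM25_search(perform_lemmatization=True)  -> also lemmatizes tokens (uses 'wordnet')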

    def preprocess(self, text: str) -> List[str]:
        """
        Preprocesses the input text by lowercasing, removing punctuation,
        tokenizing, removing stopwords, and optionally lemmatizing.
        """
        text = text.lower().translate(str.maketrans('', '', string.punctuation))
        tokens = nltk.word_tokenize(text)
        if self.remove_stopwords:
            tokens = [token for token in tokens if token not in self.stop_words]
        if self.perform_lemmatization and self.lemmatizer:
            tokens = [self.lemmatizer.lemmatize(token) for token in tokens]
        return tokens
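
    # For illustration (hypothetical input, assuming the default settings of
    # remove_stopwords=True and perform_lemmatization=False):
    #   preprocess("The quick brown fox jumps!") -> ['quick', 'brown', 'fox', 'jumps']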

    def add_document(self, doc_id: str, new_doc: str) -> None:
        """
        Adds a new document to the corpus and updates the BM25 index.
        """
        processed_tokens = self.preprocess(new_doc)
        self.documents.append(new_doc)
        self.doc_ids.append(doc_id)
        self.tokenized_docs.append(processed_tokens)
        # Rebuild the index synchronously so the new document is searchable immediately
        self.update_bm25()
        print(f"Added document ID: {doc_id}")

    async def remove_document(self, index: int) -> None:
        """
        Removes a document from the corpus based on its index and updates the
        BM25 index. Declared async for call-site compatibility, although it
        performs no awaits itself.
        """
        if 0 <= index < len(self.documents):
            removed_doc_id = self.doc_ids[index]
            del self.documents[index]
            del self.doc_ids[index]
            del self.tokenized_docs[index]
            self.update_bm25()
            print(f"Removed document ID: {removed_doc_id}")
        else:
            print(f"Index {index} is out of bounds.")

    def update_bm25(self) -> None:
        """
        Rebuilds the BM25 index from the current tokenized documents.
        """
        if self.tokenized_docs:
            self.bm25 = BM25Okapi(self.tokenized_docs)
            print("BM25 index has been updated.")
        else:
            print("No documents to initialize BM25.")

    def get_scores(self, query: str) -> List[float]:
        """
        Computes BM25 scores for all documents based on the given query.
        """
        processed_query = self.preprocess(query)
        print(f"Tokenized Query: {processed_query}")
        if self.bm25:
            # Cast the numpy array returned by rank_bm25 to a plain list,
            # matching the declared List[float] return type
            return list(self.bm25.get_scores(processed_query))
        else:
            print("BM25 is not initialized.")
            return []
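
    # The returned scores are positionally aligned with doc_ids, so callers can
    # pair and rank them, e.g. (a sketch, assuming `searcher` is an indexed instance):
    #   ranked = sorted(zip(searcher.doc_ids, searcher.get_scores(query)),
    #                   key=lambda pair: pair[1], reverse=True)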

    def get_top_n_docs(self, query: str, n: int = 5) -> List[str]:
        """
        Returns the top N documents for a given query.
        """
        processed_query = self.preprocess(query)
        if self.bm25:
            return self.bm25.get_top_n(processed_query, self.documents, n)
        else:
            print("BM25 is not initialized.")
            return []

    def clear_documents(self) -> None:
        """
        Clears all documents from the BM25 index.
        """
        self.documents = []
        self.doc_ids = []
        self.tokenized_docs = []
        self.bm25 = None  # Reset the BM25 index
        print("BM25 documents cleared and index reset.")

    def get_document(self, doc_id: str) -> str:
        """
        Retrieves a document by its document ID.

        Parameters:
        - doc_id (str): The ID of the document to retrieve.

        Returns:
        - str: The document text if found, otherwise an empty string.
        """
        try:
            index = self.doc_ids.index(doc_id)
            return self.documents[index]
        except ValueError:
            print(f"Document ID {doc_id} not found.")
            return ""


async def initialize_bm25_search(remove_stopwords: bool = True, perform_lemmatization: bool = False) -> BM25_search:
    """
    Initializes a BM25_search instance, downloading NLTK resources in an
    executor so the coroutine does not block the event loop.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, download_nltk_resources)
    return BM25_search(remove_stopwords, perform_lemmatization)
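

# Minimal usage sketch (the document texts, IDs, and query below are
# illustrative placeholders, not part of the application):
if __name__ == "__main__":
    async def _demo() -> None:
        searcher = await initialize_bm25_search()
        searcher.add_document("doc1", "The cat sat on the mat.")
        searcher.add_document("doc2", "Dogs chase cats around the yard.")
        print(searcher.get_top_n_docs("cat on a mat", n=1))

    asyncio.run(_demo())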