# Spaces:
# Build error
# Build error
# faiss_wrapper.py
import faiss | |
import numpy as np | |
class FAISS_search:
    """Embedding-based document store with nearest-neighbour search via FAISS.

    Document texts live in ``self.documents`` and their caller-supplied IDs in
    the parallel list ``self.doc_ids``. Every vector in ``self.index`` is
    stored under the document's *position* in those lists. The indices
    returned by :meth:`search` can therefore be used directly as
    ``self.documents[i]`` / ``self.doc_ids[i]``. Removing a document shifts
    positions, which is why :meth:`remove_document` rebuilds the whole index.
    """

    def __init__(self, embedding_model):
        """Create an empty store.

        Args:
            embedding_model: Encoder exposing
                ``encode(texts, convert_to_numpy=True)`` that returns a NumPy
                array (presumably a sentence-transformers model -- confirm).
                It is called once here on a probe string to learn the
                embedding dimension.
        """
        self.documents = []
        self.doc_ids = []
        self.embedding_model = embedding_model
        # Assumes encoding a single string returns a 1-D vector, so its
        # length is the embedding dimension.
        self.dimension = len(embedding_model.encode("embedding"))
        self.index = self._new_index()

    def _new_index(self):
        """Return an empty exact-L2 FAISS index that accepts explicit int64 IDs."""
        return faiss.IndexIDMap(faiss.IndexFlatL2(self.dimension))

    def _encode(self, texts):
        """Embed a list of texts as a float32 array (the dtype FAISS expects)."""
        return self.embedding_model.encode(texts, convert_to_numpy=True).astype('float32')

    def add_document(self, doc_id, new_doc):
        """Embed ``new_doc`` and append it to the store under ``doc_id``.

        The document is encoded and added to the index *before* it is recorded
        in ``documents``/``doc_ids``. If encoding fails or yields an empty
        embedding, the lists and the index therefore stay in sync.

        Args:
            doc_id: Caller-supplied identifier, retrievable via :meth:`get_document`.
            new_doc: Document text to embed.
        """
        embedding = self._encode([new_doc])
        if embedding.size == 0:
            print("No documents to add to FAISS index.")
            return
        # The vector's FAISS ID is the position the document is about to take.
        position = len(self.documents)
        self.index.add_with_ids(embedding, np.array([position], dtype='int64'))
        self.documents.append(new_doc)
        self.doc_ids.append(doc_id)

    def remove_document(self, index):
        """Remove the document at list position ``index`` and rebuild the index.

        Rebuilding re-embeds every remaining document so that vector IDs match
        the shifted positions. An out-of-range position is reported and ignored.

        Args:
            index: Position in ``self.documents``; negative values are rejected.
        """
        if not 0 <= index < len(self.documents):
            print(f"Index {index} is out of bounds.")
            return
        del self.documents[index]
        del self.doc_ids[index]
        self.build_index()

    def build_index(self):
        """Rebuild ``self.index`` from scratch by re-embedding all documents.

        Vector IDs are reassigned to ``0 .. len(documents) - 1``. The new index
        is fully built before it replaces the old one, so a failure while
        encoding leaves the previous index in place. With no documents the
        index is simply reset. Encoding an empty list may not produce a 2-D
        ``(0, dimension)`` array, which ``add_with_ids`` would reject.
        """
        new_index = self._new_index()
        if self.documents:
            embeddings = self._encode(self.documents)
            ids = np.arange(len(self.documents), dtype='int64')
            new_index.add_with_ids(embeddings, ids)
        self.index = new_index

    def search(self, query, k):
        """Find the stored documents closest to ``query``.

        Args:
            query: Query text to embed.
            k: Maximum number of neighbours to return. It is capped at the
                number of indexed documents; ``k <= 0`` yields no results.

        Returns:
            ``(distances, indices)`` as 1-D arrays for this single query, in
            the order FAISS returns them (nearest first). ``indices`` are
            positions into ``self.documents``. Both arrays are empty when the
            index is empty.
        """
        empty = (np.array([], dtype='float32'), np.array([], dtype='int64'))
        if self.index.ntotal == 0:
            print("FAISS index is empty. No results can be returned.")
            return empty
        # FAISS pads missing neighbours with ID -1 when k exceeds the number of
        # stored vectors. Used as a list index, -1 would silently select the
        # last document, so never request more neighbours than exist.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return empty
        query_embedding = self._encode([query])
        distances, indices = self.index.search(query_embedding, k)
        return distances[0], indices[0]

    def clear_documents(self) -> None:
        """Remove every document and reset the FAISS index to empty."""
        self.documents = []
        self.doc_ids = []
        self.index = self._new_index()
        print("FAISS documents cleared and index reset.")

    def get_document(self, doc_id: str) -> str:
        """Return the text of the document stored under ``doc_id``.

        If several documents share the ID, the earliest-added one is returned.

        Args:
            doc_id: ID that was passed to :meth:`add_document`.

        Returns:
            The document text, or ``""`` (after printing a notice) if the ID is
            unknown.
        """
        try:
            position = self.doc_ids.index(doc_id)
        except ValueError:
            print(f"Document ID {doc_id} not found.")
            return ""
        return self.documents[position]