import faiss
import numpy as np
import os
import pickle

# A thin wrapper around a Faiss flat (exact search) index
class IndexFlat:
    def __init__(self, dimension):
        # Initialize a Faiss flat index with L2 distance
        self.index = faiss.IndexFlatL2(dimension)

    def add(self, vectors):
        # Add vectors to the index (Faiss expects float32)
        self.index.add(np.array(vectors, dtype='float32'))

    def delete(self, ids):
        # Remove vectors from the index by their sequential IDs
        self.index.remove_ids(np.array(ids, dtype='int64'))

    def search(self, vectors, k):
        # Search for the k nearest neighbors of the given vectors
        return self.index.search(np.array(vectors, dtype='float32'), k)
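
# Usage sketch for IndexFlat (illustrative only; the 4-d toy vectors are an
# assumption, not part of the pipeline below): exact L2 search returns the
# distances and sequential IDs of the nearest stored vectors.
#   flat = IndexFlat(4)
#   flat.add(np.random.random((10, 4)).astype('float32'))
#   distances, ids = flat.search(np.random.random((1, 4)).astype('float32'), k=3)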

# A thin wrapper around a Faiss IVF (inverted file) index
class IndexIVF:
    def __init__(self, dimension, nlists=100, nprobe=10):
        # Use a flat L2 index as the coarse quantizer; the IVF metric must
        # match the quantizer, so the index uses L2 distance as well
        self.index_flat = faiss.IndexFlatL2(dimension)
        self.index = faiss.IndexIVFFlat(self.index_flat, dimension, nlists, faiss.METRIC_L2)
        self.index.nprobe = nprobe

    def add(self, vectors):
        # Train the coarse quantizer on first use, then add the vectors
        vectors = np.array(vectors, dtype='float32')
        if not self.index.is_trained:
            self.index.train(vectors)
        self.index.add(vectors)

    def delete(self, ids):
        # Remove vectors from the index by their sequential IDs
        self.index.remove_ids(np.array(ids, dtype='int64'))

    def search(self, vectors, k):
        # Search for the k nearest neighbors of the given vectors
        return self.index.search(np.array(vectors, dtype='float32'), k)
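
# Usage sketch for IndexIVF (illustrative only; the toy sizes are assumptions):
# a larger nprobe scans more inverted lists at query time, trading speed for recall.
#   ivf = IndexIVF(4, nlists=16, nprobe=4)
#   ivf.add(np.random.random((2000, 4)).astype('float32'))
#   distances, ids = ivf.search(np.random.random((1, 4)).astype('float32'), k=3)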

# Manage Faiss vector storage alongside a document lookup table
class FaissVectorStore:
    def __init__(self, dimension=324, nlists=100, nprobe=10):
        self.dimension = dimension
        self.nlists = nlists
        self.nprobe = nprobe
        self.index = None
        self.documents_db = {}

    def add(self, documents):
        db_vectors, db_documents, db_docs_ids = [], [], []
        # Collect the existing document vectors and documents
        for doc_id in self.documents_db:
            db_vectors.append(self.documents_db[doc_id]['vector'])
            db_documents.append(self.documents_db[doc_id]['document'])
            db_docs_ids.append(doc_id)
        # Append the new document vectors and documents
        for doc_id in documents:
            db_vectors.append(documents[doc_id]['vector'])
            db_documents.append(documents[doc_id]['document'])
            db_docs_ids.append(doc_id)
        # Rebuild the index from scratch: exact flat search for small
        # collections, IVF for larger ones
        if len(db_vectors) < 10000:
            self.index = IndexFlat(self.dimension)
        else:
            self.index = IndexIVF(self.dimension, self.nlists, self.nprobe)
        self.index.add(np.array(db_vectors, dtype='float32'))
        # Rebuild the document table with fresh sequential index IDs
        self.documents_db = {}
        for i, doc_id in enumerate(db_docs_ids):
            self.documents_db[doc_id] = {'vector': db_vectors[i], 'document': db_documents[i], 'index_id': i}

    def delete(self, documents_ids):
        # Drop the documents and rebuild the index: Faiss compacts its
        # sequential IDs after removal, which would leave the stored
        # index_id mapping stale
        self.documents_db = {k: v for k, v in self.documents_db.items() if k not in documents_ids}
        if self.documents_db:
            self.add({})
        else:
            self.index = None

    def query(self, query_vector, k):
        # Query for the top k nearest neighbors of query_vector, preserving
        # the ranking returned by Faiss
        _, I = self.index.search(np.atleast_2d(np.array(query_vector, dtype='float32')), k)
        index_to_document = {v['index_id']: v['document'] for v in self.documents_db.values()}
        # Faiss pads the result with -1 when fewer than k neighbors exist
        return [index_to_document[i] for i in I[0] if i in index_to_document]

    def write(self, database_path):
        # Save the index and documents to files
        if not os.path.exists(database_path):
            os.makedirs(database_path)
        faiss_path = os.path.join(database_path, 'index.faiss')
        document_path = os.path.join(database_path, 'documents.pkl')
        faiss.write_index(self.index.index, faiss_path)
        with open(document_path, 'wb') as f:
            pickle.dump(self.documents_db, f)

    def read(self, database_path):
        # Read the index and documents from files
        faiss_path = os.path.join(database_path, 'index.faiss')
        document_path = os.path.join(database_path, 'documents.pkl')
        raw_index = faiss.downcast_index(faiss.read_index(faiss_path))
        self.dimension = raw_index.d
        # Rewrap the raw index so add(), delete() and write() keep working
        wrapper = (IndexIVF(self.dimension, self.nlists, self.nprobe)
                   if isinstance(raw_index, faiss.IndexIVFFlat) else IndexFlat(self.dimension))
        wrapper.index = raw_index
        self.index = wrapper
        with open(document_path, 'rb') as f:
            self.documents_db = pickle.load(f)

    @classmethod
    def from_documents(cls, documents, dimension, nlists, nprobe):
        # Build a new store from a {doc_id: {'document': ..., 'vector': ...}} mapping
        vector_store = cls(dimension, nlists, nprobe)
        vector_store.add(documents)
        return vector_store

    @classmethod
    def as_retriever(cls, database_path):
        # Load a previously persisted store from disk
        vector_store = cls()
        vector_store.read(database_path)
        return vector_store

if __name__ == '__main__':
    nb = 20000
    d = 50
    database_path = 'db_path'
    if not os.path.exists(database_path):
        os.makedirs(database_path)
    # Build a toy corpus of random vectors
    documents = {}
    for i in range(nb):
        doc_id = f'id_{i}'
        text = f'text_{i}'
        vector = np.random.random(d).astype('float32')
        documents[doc_id] = {'document': text, 'vector': vector}
    vector_store = FaissVectorStore.from_documents(documents, dimension=50, nlists=100, nprobe=10)
    query_vector = np.random.random((1, d)).astype('float32')
    nearest_neighbors = vector_store.query(query_vector, k=5)
    print(nearest_neighbors)
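
    # A minimal persistence round trip (sketch): save the store, reload it with
    # as_retriever, and rerun the same query against the reloaded index; the
    # reloaded_store name is illustrative, everything else is defined above.
    vector_store.write(database_path)
    reloaded_store = FaissVectorStore.as_retriever(database_path)
    print(reloaded_store.query(query_vector, k=5))
    # delete() rebuilds the index over the remaining documents
    vector_store.delete({'id_0', 'id_1'})
    print(vector_store.query(query_vector, k=5))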