# Faiss-backed vector store (flat + IVF indexes) with simple disk persistence.
# NOTE: web-scrape artifacts (page header, file size, commit eccde2c, line gutter)
# were removed from this header; they were not part of the Python source.
import faiss
import numpy as np
import os
import pickle
from tqdm import tqdm
# Create a class for a flat index
class IndexFlat:
    """Thin wrapper around a Faiss exact (brute-force) index with L2 distance."""

    def __init__(self, dimension):
        """Create an exact-search index over `dimension`-dimensional vectors."""
        self.index = faiss.IndexFlatL2(dimension)

    def add(self, vectors):
        """Append the given vectors to the index."""
        data = np.array(vectors)
        self.index.add(data)

    def delete(self, ids):
        """Remove the vectors whose positional ids appear in `ids`."""
        selection = np.array(ids)
        self.index.remove_ids(selection)

    def search(self, vectors, k):
        """Return (distances, indices) of the k nearest neighbours of `vectors`."""
        queries = np.array(vectors)
        return self.index.search(queries, k)
# Create a class for an IVF (Inverted File) index
class IndexIVF:
    """Wrapper around a Faiss IVF (inverted-file) index for approximate search.

    The coarse quantizer is an exact L2 index, so the search metric is L2 as
    well.  The original code paired the L2 quantizer with
    ``faiss.METRIC_INNER_PRODUCT``, which makes cluster assignment and final
    ranking use different metrics and returns inconsistent neighbours; it is
    also inconsistent with the L2-based ``IndexFlat`` that the store swaps
    this index with.
    """

    def __init__(self, dimension, nlists=100, nprobe=10):
        """Build an IVF index with `nlists` clusters, probing `nprobe` at query time."""
        self.index_flat = faiss.IndexFlatL2(dimension)
        # Quantizer metric and index metric must agree (both L2).
        self.index = faiss.IndexIVFFlat(self.index_flat, dimension, nlists, faiss.METRIC_L2)
        self.index.nprobe = nprobe

    def add(self, vectors):
        """Train the index on first use, then add `vectors`.

        Training is skipped once the index is trained; the original code
        retrained (recomputed centroids) on every call.
        """
        data = np.asarray(vectors, dtype='float32')  # Faiss requires float32
        if not self.index.is_trained:
            self.index.train(data)
        self.index.add(data)

    def delete(self, ids):
        """Remove the vectors whose positional ids appear in `ids`."""
        self.index.remove_ids(np.array(ids))

    def search(self, vectors, k):
        """Return (distances, indices) of the k nearest neighbours of `vectors`."""
        return self.index.search(np.asarray(vectors, dtype='float32'), k)
# Create a class for managing Faiss vector storage
class FaissVectorStore:
    """In-memory document store backed by a Faiss index, with disk persistence.

    Maintains ``documents_db``: a mapping ``doc_id -> {'vector', 'document',
    'index_id'}`` where ``index_id`` is the vector's position inside the Faiss
    index.  Collections smaller than ``FLAT_INDEX_MAX`` use an exact flat
    index; larger ones use an approximate IVF index.
    """

    # Below this many vectors an exact flat index is used instead of IVF.
    FLAT_INDEX_MAX = 10000

    def __init__(self, dimension=324, nlists=100, nprobe=10):
        """Create an empty store for `dimension`-dimensional vectors."""
        self.dimension = dimension
        self.nlists = nlists
        self.nprobe = nprobe
        self.index = None       # IndexFlat / IndexIVF wrapper; None until add()/read()
        self.documents_db = {}  # doc_id -> {'vector', 'document', 'index_id'}

    def _rebuild(self, doc_ids, vectors, texts):
        """Recreate the Faiss index from scratch and refresh the id mapping.

        Rebuilding keeps every stored ``index_id`` equal to the vector's
        position in the index, which ``query`` and ``delete`` depend on.
        """
        if len(vectors) < self.FLAT_INDEX_MAX:
            self.index = IndexFlat(self.dimension)
        else:
            self.index = IndexIVF(self.dimension, self.nlists, self.nprobe)
        if vectors:  # adding an empty array would fail / be pointless
            self.index.add(np.array(vectors))
        self.documents_db = {
            doc_id: {'vector': vectors[i], 'document': texts[i], 'index_id': i}
            for i, doc_id in enumerate(doc_ids)
        }

    def add(self, documents):
        """Add `documents` (``{doc_id: {'vector', 'document'}}``) to the store.

        The index is rebuilt from the union of existing and new documents so
        index ids stay contiguous.  (An unused ``ids`` local from the original
        implementation was removed.)
        """
        doc_ids, vectors, texts = [], [], []
        for doc_id, entry in self.documents_db.items():
            doc_ids.append(doc_id)
            vectors.append(entry['vector'])
            texts.append(entry['document'])
        for doc_id, entry in documents.items():
            doc_ids.append(doc_id)
            vectors.append(entry['vector'])
            texts.append(entry['document'])
        self._rebuild(doc_ids, vectors, texts)

    def delete(self, documents_ids):
        """Remove the given document ids from the store.

        Faiss compacts positional ids after ``remove_ids``, which left the
        stored ``index_id`` values stale in the original implementation;
        rebuilding the index from the surviving documents keeps the mapping
        correct.
        """
        removed = set(documents_ids)
        doc_ids, vectors, texts = [], [], []
        for doc_id, entry in self.documents_db.items():
            if doc_id not in removed:
                doc_ids.append(doc_id)
                vectors.append(entry['vector'])
                texts.append(entry['document'])
        self._rebuild(doc_ids, vectors, texts)

    def query(self, query_vector, k):
        """Return the documents of the k nearest neighbours of `query_vector`.

        `query_vector` is expected to be a 2-D array of shape (1, dimension).
        """
        _, neighbour_ids = self.index.search(query_vector, k)
        # Faiss pads short result lists with -1; ignore those slots.
        hits = {int(i) for i in neighbour_ids[0] if i != -1}
        return [entry['document'] for entry in self.documents_db.values()
                if entry['index_id'] in hits]

    def write(self, database_path):
        """Persist the Faiss index and the document mapping under `database_path`."""
        os.makedirs(database_path, exist_ok=True)
        faiss_path = os.path.join(database_path, 'index.faiss')
        document_path = os.path.join(database_path, 'documents.pkl')
        faiss.write_index(self.index.index, faiss_path)
        with open(document_path, 'wb') as f:
            pickle.dump(self.documents_db, f)

    def read(self, database_path):
        """Load a store previously saved with ``write``.

        The raw Faiss index is re-wrapped in ``IndexFlat`` so that ``delete``,
        ``query`` and ``write`` keep working after a reload — the original
        code assigned the bare Faiss index, which has no ``delete`` method and
        no ``.index`` attribute for ``write``.
        """
        faiss_path = os.path.join(database_path, 'index.faiss')
        document_path = os.path.join(database_path, 'documents.pkl')
        raw_index = faiss.read_index(faiss_path)
        wrapper = IndexFlat(raw_index.d)
        wrapper.index = raw_index  # reuse the loaded index regardless of its concrete type
        self.index = wrapper
        # NOTE(review): pickle is only safe for files this process wrote itself;
        # never point `read` at untrusted data.
        with open(document_path, 'rb') as f:
            self.documents_db = pickle.load(f)

    @classmethod
    def from_documents(cls, documents, dimension, nlists, nprobe):
        """Build a store directly from a ``{doc_id: {'vector', 'document'}}`` mapping."""
        vector_store = cls(dimension, nlists, nprobe)
        vector_store.add(documents)
        return vector_store

    @classmethod
    def as_retriever(cls, database_path):
        """Load a store previously saved with ``write`` and return it."""
        vector_store = cls()
        vector_store.read(database_path)
        return vector_store
if __name__ == '__main__':
    # Demo: build a store from random vectors and run a single query.
    nb = 20000  # number of documents (>= 10000, so the IVF path is exercised)
    d = 50      # vector dimensionality
    database_path = 'db_path'
    if not os.path.exists(database_path):
        os.makedirs(database_path)
    documents = {}
    for i in range(nb):
        doc_id = f'id_{i}'  # was `id`, which shadowed the builtin
        text = f'text_{i}'
        vector = np.random.random(d).astype('float32')  # Faiss requires float32
        documents[doc_id] = {'document': text, 'vector': vector}
    vector_store = FaissVectorStore.from_documents(documents, dimension=50, nlists=100, nprobe=10)
    query_vector = np.random.random((1, d)).astype('float32')
    nearest_neighbors = vector_store.query(query_vector, k=5)
    print(nearest_neighbors)