import faiss
import numpy as np
import os
import pickle

# Create a class for a flat (exact, brute-force) index
class IndexFlat:
    def __init__(self, dimension):
        # Initialize a Faiss flat index with L2 distance
        self.index = faiss.IndexFlatL2(dimension)

    def add(self, vectors):
        # Add vectors to the index (Faiss expects float32)
        self.index.add(np.asarray(vectors, dtype='float32'))

    def delete(self, ids):
        # Remove vectors from the index by their sequential IDs
        self.index.remove_ids(np.asarray(ids, dtype=np.int64))

    def search(self, vectors, k):
        # Search for the k-nearest neighbors of the given vectors
        return self.index.search(np.asarray(vectors, dtype='float32'), k)

# Create a class for an IVF (Inverted File) index
class IndexIVF:
    def __init__(self, dimension, nlists=100, nprobe=10):
        # Initialize an IVF index with a flat L2 quantizer; use the L2 metric so
        # results stay consistent with IndexFlat
        self.quantizer = faiss.IndexFlatL2(dimension)
        self.index = faiss.IndexIVFFlat(self.quantizer, dimension, nlists, faiss.METRIC_L2)
        self.index.nprobe = nprobe

    def add(self, vectors):
        # Train the index on the data (if not already trained), then add the vectors
        vectors = np.asarray(vectors, dtype='float32')
        if not self.index.is_trained:
            self.index.train(vectors)
        self.index.add(vectors)

    def delete(self, ids):
        # Remove vectors from the index by their sequential IDs
        self.index.remove_ids(np.asarray(ids, dtype=np.int64))

    def search(self, vectors, k):
        # Search for the k-nearest neighbors of the given vectors
        return self.index.search(np.asarray(vectors, dtype='float32'), k)

# Create a class for managing Faiss vector storage
class FaissVectorStore:
    def __init__(self, dimension=324, nlists=100, nprobe=10):
        self.dimension = dimension
        self.nlists = nlists
        self.nprobe = nprobe
        self.index = None
        self.documents_db = {}

    def add(self, documents):
        db_vectors, db_documents, db_docs_ids = [], [], []

        # Collect existing document vectors and documents
        for doc_id in self.documents_db:
            db_vectors.append(self.documents_db[doc_id]['vector'])
            db_documents.append(self.documents_db[doc_id]['document'])
            db_docs_ids.append(doc_id)

        # Add new document vectors and documents
        for doc_id in documents:
            db_vectors.append(documents[doc_id]['vector'])
            db_documents.append(documents[doc_id]['document'])
            db_docs_ids.append(doc_id)

        # Rebuild the index: exact flat search for small collections, IVF for larger ones
        if len(db_vectors) < 10000:
            self.index = IndexFlat(self.dimension)
        else:
            self.index = IndexIVF(self.dimension, self.nlists, self.nprobe)

        self.index.add(np.array(db_vectors, dtype='float32'))
        self.documents_db = {}
        for i, doc_id in enumerate(db_docs_ids):
            self.documents_db[doc_id] = {'vector': db_vectors[i], 'document': db_documents[i], 'index_id': i}
        
    def delete(self, documents_ids):
        # Delete documents and rebuild the index from the remaining entries:
        # remove_ids compacts a flat Faiss index, which would invalidate the
        # stored index_id mapping
        remaining = {doc_id: {'vector': entry['vector'], 'document': entry['document']}
                     for doc_id, entry in self.documents_db.items()
                     if doc_id not in documents_ids}
        self.documents_db = {}
        self.index = None
        if remaining:
            self.add(remaining)
        

    def query(self, query_vector, k):
        # Query for the top-k nearest neighbors of the query_vector
        query_vector = np.asarray(query_vector, dtype='float32').reshape(1, -1)
        _, I = self.index.search(query_vector, k)
        # Map index ids back to documents, preserving the ranking returned by Faiss
        # (missing results are padded with -1 and skipped here)
        id_to_doc = {entry['index_id']: entry['document'] for entry in self.documents_db.values()}
        return [id_to_doc[idx] for idx in I[0] if idx in id_to_doc]

    def write(self, database_path):
        # Save the index and documents to files
        if not os.path.exists(database_path):
            os.makedirs(database_path)
        faiss_path = os.path.join(database_path, 'index.faiss')
        document_path = os.path.join(database_path, 'documents.pkl')
        # self.index may be a wrapper (IndexFlat / IndexIVF) or a raw Faiss index loaded by read()
        raw_index = self.index.index if hasattr(self.index, 'index') else self.index
        faiss.write_index(raw_index, faiss_path)
        with open(document_path, 'wb') as f:
            pickle.dump(self.documents_db, f)

    def read(self, database_path):
        # Read the index and documents from files; the loaded index is a raw
        # Faiss index, which still provides search() for query()
        faiss_path = os.path.join(database_path, 'index.faiss')
        document_path = os.path.join(database_path, 'documents.pkl')
        self.index = faiss.read_index(faiss_path)
        with open(document_path, 'rb') as f:
            self.documents_db = pickle.load(f)

    @classmethod
    def from_documents(cls, documents, dimension, nlists, nprobe):
        vector_store = cls(dimension, nlists, nprobe)
        vector_store.add(documents)
        return vector_store

    @classmethod
    def as_retriever(cls, database_path):
        vector_store = cls()
        vector_store.read(database_path)
        return vector_store

if __name__ == '__main__':
    nb = 20000
    d = 50
    database_path = 'db_path'

    if not os.path.exists(database_path):
        os.makedirs(database_path)

    # Build a dictionary of documents with random vectors
    documents = {}
    for i in range(nb):
        doc_id = f'id_{i}'
        text = f'text_{i}'
        vector = np.random.random(d).astype('float32')
        documents[doc_id] = {'document': text, 'vector': vector}

    vector_store = FaissVectorStore.from_documents(documents, dimension=50, nlists=100, nprobe=10)
    query_vector = np.random.random((1, d)).astype('float32')
    nearest_neighbors = vector_store.query(query_vector, k=5)
    print(nearest_neighbors)
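
    # A minimal persistence round-trip sketch, assuming the 'db_path' directory
    # created above is writable: save the store, reload it with as_retriever,
    # and re-run the same query against the restored index.
    vector_store.write(database_path)
    restored_store = FaissVectorStore.as_retriever(database_path)
    print(restored_store.query(query_vector, k=5))

    # Deleting documents rebuilds the index from the remaining entries
    # (the ids below are just the synthetic ids generated above).
    vector_store.delete(['id_0', 'id_1'])
    print(vector_store.query(query_vector, k=5))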