File size: 7,207 Bytes
bd3532f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from typing import Dict, Any, Optional, List

import chromadb
from chromadb.config import Settings
from chromadb.api.types import (
    Where,
    GetResult,
    QueryResult,
)

from ..embedding_provider import EmbeddingProvider
from .vector_store import VectorStore


class ChromaDB(VectorStore):
    """
    ChromaDB-backed implementation of the ``VectorStore`` interface.

    Wraps a persistent ``chromadb`` client and one active collection. When an
    ``embedding_function`` is supplied, embeddings for documents and queries
    are computed with it and sent to Chroma explicitly; otherwise raw texts
    are sent and embedding is left to Chroma.

    See more:
    https://github.com/chroma-core/chroma
    """

    # Collection used when no collection_name is configured.
    DEFAULT_COLLECTION_NAME = "default_collection"

    def __init__(
            self,
            configs: Optional[Dict[str, Any]] = None,
            db_path: str = ".chromadb",
            embedding_function: Optional[EmbeddingProvider] = None,
            collection_name: Optional[str] = None,
    ) -> None:
        """
        Open (or create) a persistent Chroma database and its active collection.

        Args:
            configs: Extra configuration stored as ``self.configs`` (not read
                anywhere in this class). A fresh dict is used when omitted.
            db_path: Directory for the persistent Chroma client.
            embedding_function: Optional provider used to embed documents and
                queries before they are sent to Chroma.
            collection_name: Collection to use; ``DEFAULT_COLLECTION_NAME`` is
                used when not given.
        """
        self.client = chromadb.PersistentClient(path=db_path)
        # A fresh dict per instance: a ``{}`` default would be shared by all instances.
        self.configs = configs if configs is not None else {}

        self.embedding_function = embedding_function
        self._collection_name = collection_name

        self.collection = self.client.get_or_create_collection(
            name=self.collection_name or self.DEFAULT_COLLECTION_NAME
        )

    @property
    def db_path(self) -> str:
        """Directory the persistent client stores its data in."""
        return self.client.get_settings().persist_directory

    @db_path.setter
    def db_path(self, value: str) -> None:
        # Re-open the client at the new location and re-attach (or create)
        # the active collection there.
        self.client = chromadb.PersistentClient(path=value)
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name or self.DEFAULT_COLLECTION_NAME
        )

    @property
    def collection_name(self) -> Optional[str]:
        """Configured collection name, or ``None`` when the default name is in use."""
        return self._collection_name

    @collection_name.setter
    def collection_name(self, value: str) -> None:
        # Rename in Chroma first so the cached name only changes once the
        # rename has succeeded.
        self.collection.modify(name=value)
        self._collection_name = value

    def add_data(
            self,
            documents: List[str],
            ids: List[str],
            metadatas: Optional[List[Dict[str, Any]]] = None,
            **optional_kwargs
    ) -> None:
        """
        Add documents to the active collection.

        When an ``embedding_function`` is configured, embeddings are computed
        with ``embed_documents`` and passed alongside the documents.

        Args:
            documents (List[str]): Documents to add.
            ids (List[str]): One id per document.
            metadatas (Optional[List[Dict[str, Any]]]): Optional metadata per
                document; an empty list is sent as ``None``.
            **optional_kwargs: Additional keyword arguments forwarded to
                ``collection.add``.

        Raises:
            Exception: Any error from embedding or ``collection.add`` is
                printed and re-raised.
        """
        try:
            params: Dict[str, Any] = {
                "documents": documents,
                "ids": ids,
                **optional_kwargs,
            }
            # Normalise "no metadata" (None or []) to None.
            params["metadatas"] = metadatas or None

            if self.embedding_function:
                params["embeddings"] = self.embedding_function.embed_documents(documents)

            self.collection.add(**params)
        except Exception as e:
            print(f"Error adding data to collection: {e}")
            raise

    def search(
            self,
            query_text: Optional[List[str]] = None,
            query_embedding: Optional[List[List[float]]] = None,
            n_results: int = 10,
            **optional_kwargs
    ) -> QueryResult:
        """
        Query the collection for the documents most similar to the queries.

        Behaviour by input:
          * texts only, with ``embedding_function``: texts are embedded with
            ``embed_query`` and sent as ``query_embeddings``;
          * texts only, without ``embedding_function``: sent as ``query_texts``;
          * embeddings only: sent as ``query_embeddings``;
          * both, with ``embedding_function``: the given embeddings followed by
            the text embeddings are sent as ``query_embeddings``.

        Args:
            query_text (Optional[List[str]]): Query texts.
            query_embedding (Optional[List[List[float]]]): Query embeddings.
                Never mutated.
            n_results (int): Number of results to return per query.
            **optional_kwargs: Additional keyword arguments forwarded to
                ``collection.query``.

        Returns:
            QueryResult: The result of the query.

        Raises:
            ValueError: If neither a non-empty ``query_text`` nor a non-empty
                ``query_embedding`` is given, or if both are given without an
                ``embedding_function`` (Chroma's ``query`` accepts texts or
                embeddings, not both).
            Exception: Any error is printed and re-raised.
        """
        try:
            if not query_text and not query_embedding:
                raise ValueError("Either query_text or query_embedding must be provided.")

            params: Dict[str, Any] = {
                "n_results": n_results,
                **optional_kwargs,
            }

            if query_text and self.embedding_function:
                # NOTE(review): embed_query receives the whole list of query
                # strings; confirm EmbeddingProvider.embed_query accepts a list
                # and returns one embedding per text.
                text_embeddings = self.embedding_function.embed_query(query_text)
                if query_embedding:
                    # Build a new list: list.extend returns None and would
                    # also mutate the caller's argument.
                    params["query_embeddings"] = list(query_embedding) + list(text_embeddings)
                else:
                    params["query_embeddings"] = text_embeddings
            elif query_text:
                if query_embedding:
                    raise ValueError(
                        "query_text and query_embedding cannot be combined "
                        "without an embedding_function."
                    )
                # Chroma's Collection.query expects the keyword ``query_texts``.
                params["query_texts"] = query_text
            else:
                params["query_embeddings"] = query_embedding

            return self.collection.query(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def query_by_id_or_metadata(
            self,
            ids: Optional[List[str]] = None,
            where: Optional[Where] = None,
            n_results: int = 10,
            **optional_kwargs
    ) -> GetResult:
        """
        Fetch documents by id and/or metadata filter (no similarity ranking).

        Args:
            ids (Optional[List[str]]): Ids of the documents to fetch.
            where (Optional[Where]): Metadata filter passed to Chroma as ``where``.
            n_results (int): Maximum number of records to return. Sent to
                ``collection.get`` as ``limit`` (this also caps id lookups);
                a ``limit`` in ``optional_kwargs`` takes precedence.
            **optional_kwargs: Additional keyword arguments forwarded to
                ``collection.get``.

        Returns:
            GetResult: The matching records.

        Raises:
            ValueError: If neither non-empty ``ids`` nor a non-empty ``where``
                is given.
            Exception: Any error is printed and re-raised.
        """
        try:
            if not ids and not where:
                raise ValueError("Either ids or where must be provided.")

            # Collection.get has no ``n_results`` parameter; its cap is ``limit``.
            params: Dict[str, Any] = {
                "limit": n_results,
                **optional_kwargs,
            }
            if ids:
                params["ids"] = ids
            if where:
                params["where"] = where

            return self.collection.get(**params)
        except Exception as e:
            print(f"Error querying data from collection: {e}")
            raise

    def delete_collection(self, collection_name: Optional[str] = None) -> None:
        """
        Delete a collection from the ChromaDB client.

        Best-effort: errors are printed, not raised.

        Args:
            collection_name (Optional[str]): Name of the collection to delete.
                Defaults to the active collection, including
                ``DEFAULT_COLLECTION_NAME`` when no name was configured.
        """
        try:
            target_collection = (
                collection_name or self.collection_name or self.DEFAULT_COLLECTION_NAME
            )
            # NOTE(review): self.collection is not refreshed here; if the
            # active collection is deleted, later calls on it will presumably
            # fail until a collection is re-attached.
            self.client.delete_collection(name=target_collection)
            print(f"Collection '{target_collection}' deleted successfully.")
        except Exception as e:
            print(f"Error deleting collection: {e}")