File size: 3,251 Bytes
40a2cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Load pysqlite3 and re-register it in sys.modules under the name "sqlite3",
# so any later `import sqlite3` in this process resolves to pysqlite3.
# NOTE(review): presumably a workaround for a system sqlite3 that is older
# than what chromadb needs, which would mean this has to stay above
# `import chromadb` — confirm before moving or removing.
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")


import uuid
from collections import defaultdict
from typing import Any, List, Optional

import chromadb
import numpy as np
from chromadb import Collection
from embeddings import Embedding
from PIL.Image import Image

from utils import base64_to_image


class ChromaStore:
    """Thin wrapper around a persistent Chroma client for image retrieval.

    Images are stored as base64 strings (the Chroma "documents") next to
    their embeddings, and are decoded back into images when queried.
    """

    def __init__(
        self,
        collection_name: str,
        storage_path: str = "./chroma",
        database: str = "database",
        metadata: Optional[dict] = None,
    ) -> None:
        """Open the on-disk Chroma store.

        Args:
        - collection_name(str): name of the collection
        - storage_path(str): directory passed to `chromadb.PersistentClient`.
        - database(str): kept on the instance; no method of this class
          reads it.
        - metadata(dict): collection metadata. Available options for
          'hnsw:space' are 'l2', 'ip' or 'cosine'. Defaults to
          {"hnsw:space": "cosine"} when omitted.
        """

        self.collection_name = collection_name
        # Build the default per instance and copy caller-supplied dicts so
        # that mutating `self.metadata` never leaks into other instances
        # (the previous dict literal default was shared by all of them).
        if metadata is None:
            self.metadata = {"hnsw:space": "cosine"}
        else:
            self.metadata = dict(metadata)
        self.storage_path = storage_path
        self.database = database

        self.client = chromadb.PersistentClient(path=self.storage_path)

    def _health_check(self) -> bool:
        """Return True when the client's heartbeat comes back as an int."""
        return isinstance(self.client.heartbeat(), int)

    def generate_embeddings(
        self, images: List[Image], embedding: Embedding
    ) -> np.ndarray:
        """Encode `images` with the given embedding model.

        Args:
        - images: list of images to encode.
        - embedding: model object exposing `encode_images`.

        Returns:
        - whatever `embedding.encode_images(images)` returns.
        """
        return embedding.encode_images(images)

    def create(self) -> Collection:
        """Return the collection named `self.collection_name`, creating it
        if it does not exist yet.

        `self.metadata` is forwarded so that the configured 'hnsw:space'
        is actually applied instead of being silently dropped.
        NOTE(review): how Chroma treats metadata passed for an already
        existing collection depends on the installed version — confirm.
        """
        return self.client.get_or_create_collection(
            name=self.collection_name,
            # Some Chroma versions reject an empty metadata dict; send None
            # in that case so the library falls back to its own defaults.
            metadata=self.metadata or None,
        )

    def add(
        self,
        collection: Collection,
        embeddings: List[float],
        documents: List[str],
        ids: List[str],
    ) -> None:
        """Add embeddings, documents to index or collection.

        Args:
        - collection: created collection.
        - embeddings: list of image embeddings
        - documents: list of base64 string of images
        - ids: list of ids for images.

        Raises:
        - Exception: if Chroma rejects the insert; the original error is
          attached as `__cause__`.
        """
        try:
            collection.add(
                embeddings=embeddings,
                ids=ids,
                documents=documents,
            )
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise Exception(f"Failed to add documents to Chroma store. {e}") from e

    def query(
        self,
        collection: Collection,
        query_embedding: List[float],
        top_k: int = 3,
    ) -> list:
        """Retrieve relevant images from chroma database.

        Args:
        - collection: created collection.
        - query_embedding: query image embedding.
        - top_k (int): top k images to retrieve.

        Returns:
        - list of (image, score) tuples, where score is the value Chroma
          reports under "distances", rounded to 3 decimals.
        """
        result = collection.query(query_embeddings=query_embedding, n_results=top_k)
        # Only the first entry of each result list is read, i.e. the hits
        # for the first query embedding.
        documents = result["documents"][0]
        distances = result["distances"][0]
        return [
            (base64_to_image(img_str), round(distance, 3))
            for img_str, distance in zip(documents, distances)
        ]

    def delete(self, collection_name: str) -> bool:
        """Delete the collection called `collection_name`.

        Returns:
        - True when the collection was deleted.

        Raises:
        - Exception: if the client fails to delete the collection; the
          original error is attached as `__cause__`.
        """
        try:
            self.client.delete_collection(collection_name)
        except Exception as e:
            # Same message shape as `add`: one readable string rather than
            # a (message, exception) tuple in the exception args.
            raise Exception(f"Failed to delete collection. {e}") from e
        return True

    @staticmethod
    def collection_info(collection: Collection) -> defaultdict:
        """Summarise a collection.

        Returns:
        - defaultdict(str) with "count" (number of stored items) and
          "top_10_items" (result of peeking at up to 10 items). Missing
          keys yield "" as before.
        """
        info = defaultdict(str)
        info["count"] = collection.count()
        # Request 10 explicitly so the key name stays accurate regardless
        # of the library's default peek size.
        info["top_10_items"] = collection.peek(limit=10)
        return info