# imports
from abc import ABC
from typing import Any, List
from chroma_datasets.types import AddEmbedding, Datapoint
from chroma_datasets.utils import load_huggingface_dataset, to_chroma_schema
from chromadb.utils import embedding_functions
import os
from dotenv import load_dotenv

# Load variables from a local .env file (if present) before reading the API key.
load_dotenv()
HF_API_KEY = os.environ.get("HF_API_KEY")

# Setup instructions shown to the user when a dataset requires an embedding
# function that was not supplied by the caller.
ef_instruction_dict = {
    "HuggingFaceEmbeddingFunction": """
        from chromadb.utils import embedding_functions
        hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
            api_key=HF_API_KEY,  # your Hugging Face API key
            model_name="mixedbread-ai/mxbai-embed-large-v1",
        )
    """
}

class Dataset(ABC):
    """
    Abstract base class for a dataset.

    All datasets should inherit from this class.

    Attributes:
        hf_dataset_name: name of the dataset on the Hugging Face Hub
        hf_data: the raw data loaded from Hugging Face
        embedding_function: class name of the embedding function used to generate the embeddings
        embedding_function_instructions: instructions telling the user how to set up the embedding function
    """
    hf_dataset_name: str
    hf_data: Any
    embedding_function: str
    embedding_function_instructions: str

    @classmethod
    def load_data(cls):
        cls.hf_data = load_huggingface_dataset(
            cls.hf_dataset_name,
            split_name="data"
        )

    @classmethod
    def raw_text(cls) -> str:
        if cls.hf_data is None:
            cls.load_data()
        return "\n".join(cls.hf_data["document"])
    
    @classmethod
    def chunked(cls) -> List[Datapoint]:
        if cls.hf_data is None:
            cls.load_data()
        return cls.hf_data
    
    @classmethod
    def to_chroma(cls) -> AddEmbedding:
        return to_chroma_schema(cls.chunked())
    

class Memoires_DS(Dataset):
    """
    Pre-chunked, pre-embedded dataset hosted on the Hugging Face Hub as
    eliot-hub/memoires_vec_800, embedded with HuggingFaceEmbeddingFunction.
    """
    hf_data = None
    hf_dataset_name = "eliot-hub/memoires_vec_800"
    embedding_function = "HuggingFaceEmbeddingFunction"
    embedding_function_instructions = ef_instruction_dict[embedding_function]


def import_into_chroma(chroma_client, dataset, collection_name=None, embedding_function=None, batch_size=5000):
    """
    Imports a dataset into Chroma in batches.

    Args:
        chroma_client (ChromaClient): The ChromaClient to use.
        collection_name (str): The name of the collection to load the dataset into.
        dataset (AddEmbedding): The dataset to load.
        embedding_function (Optional[Callable[[str], np.ndarray]]): A function that takes a string and returns an embedding.
        batch_size (int): The size of each batch to load.
    """
    # if chromadb is not installed, raise an error
    try:
        import chromadb
        from chromadb.utils import embedding_functions
    except ImportError:
        raise ImportError("Please install chromadb to use this function. `pip install chromadb`")

    ef = None

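    # Validate the caller-supplied embedding function: query embeddings must be
    # produced by the same model as the dataset's stored embeddings, so the
    # class name has to match what the dataset declares.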
    if dataset.embedding_function is not None:
        if embedding_function is None:
            error_msg = "See documentation"
            if dataset.embedding_function_instructions is not None:
                error_msg = dataset.embedding_function_instructions

            raise ValueError(f"""
                             Dataset requires embedding function: {dataset.embedding_function}.
                             {error_msg}
                             """)

        if embedding_function.__class__.__name__ != dataset.embedding_function:
            raise ValueError(f"Please use {dataset.embedding_function} as the embedding function for this dataset. You passed {embedding_function.__class__.__name__}")

    if embedding_function is not None:
        ef = embedding_function

    # if collection_name is None, default to the dataset class name
    if collection_name is None:
        collection_name = dataset.__name__

    if ef is None:
        ef = embedding_functions.DefaultEmbeddingFunction()

    print("########### Init collection ###########")
    collection = chroma_client.create_collection(
        collection_name,
        embedding_function=ef
    )

    # Retrieve the data mapped to Chroma's add() schema, then drop the original
    # reference so the raw dataset can be garbage-collected.
    print("########### Init to_chroma ###########")
    mapped_data = dataset.to_chroma()
    del dataset
    
    # Split the data into batches and add them to the collection
    def chunk_data(data, size):
        """Helper function to split data into batches."""
        for i in range(0, len(data), size):
            yield data[i:i+size]

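    # Chroma imposes a per-call limit on add(), so the four parallel arrays
    # (ids, metadatas, documents, embeddings) are split into aligned batches.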
    print("########### Chunking ###########")
    ids_batches = list(chunk_data(mapped_data["ids"], batch_size))
    metadatas_batches = list(chunk_data(mapped_data["metadatas"], batch_size))
    documents_batches = list(chunk_data(mapped_data["documents"], batch_size))
    embeddings_batches = list(chunk_data(mapped_data["embeddings"], batch_size))

    total_docs = len(mapped_data["ids"])

    print("########### Iterating batches ###########")
    for i, (ids, metadatas, documents, embeddings) in enumerate(zip(ids_batches, metadatas_batches, documents_batches, embeddings_batches)):
        collection.add(
            ids=ids,
            metadatas=metadatas,
            documents=documents,
            embeddings=embeddings,
        )
        print(f"Batch {i+1}/{len(ids_batches)}: Loaded {len(ids)} documents.")

    print(f"Successfully loaded {total_docs} documents into the collection named: {collection_name}")
    

    return collection
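

# Example usage: a minimal sketch, not part of the module's public API. It
# assumes chromadb is installed, HF_API_KEY is set, and a local persistent
# store; the "./chroma_db" path is an illustrative choice.
if __name__ == "__main__":
    import chromadb

    client = chromadb.PersistentClient(path="./chroma_db")
    hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(
        api_key=HF_API_KEY,
        model_name="mixedbread-ai/mxbai-embed-large-v1",
    )
    collection = import_into_chroma(
        chroma_client=client,
        dataset=Memoires_DS,
        embedding_function=hf_ef,
    )
    print(f"Collection contains {collection.count()} records.")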