# imports
from abc import ABC, abstractmethod
from typing import Optional, Union, Sequence, Dict, Mapping, List, Any
from typing_extensions import TypedDict
from chroma_datasets.types import AddEmbedding, Datapoint
from chroma_datasets.utils import load_huggingface_dataset, to_chroma_schema
from chromadb.utils import embedding_functions
import os
from dotenv import load_dotenv

# Load variables from a local .env file (e.g. HF_API_KEY) before reading them
load_dotenv()

HF_API_KEY = os.environ.get("HF_API_KEY")

ef_instruction_dict = {
    "HuggingFaceEmbeddingFunction": """
        from chromadb.utils import embedding_functions
        hf_ef = embedding_functions.huggingface_embedding_function.HuggingFaceEmbeddingFunction(api_key={HF_API_KEY}, model_name="mixedbread-ai/mxbai-embed-large-v1")
    """
}


class Dataset(ABC):
    """
    Abstract class for a dataset.

    All datasets should inherit from this class.

    Properties:
        hf_dataset_name: the name of the dataset on Hugging Face
        hf_data: the raw data from Hugging Face
        embedding_function: the name of the embedding function used to generate the embeddings
        embedding_function_instructions: instructions telling the user how to set up the embedding function
    """
    hf_dataset_name: str
    hf_data: Any
    embedding_function: str
    embedding_function_instructions: str

    @classmethod
    def load_data(cls):
        cls.hf_data = load_huggingface_dataset(
            cls.hf_dataset_name,
            split_name="data"
        )

    @classmethod
    def raw_text(cls) -> str:
        if cls.hf_data is None:
            cls.load_data()
        return "\n".join(cls.hf_data["document"])

    @classmethod
    def chunked(cls) -> List[Datapoint]:
        if cls.hf_data is None:
            cls.load_data()
        return cls.hf_data

    @classmethod
    def to_chroma(cls) -> AddEmbedding:
        return to_chroma_schema(cls.chunked())


class Memoires_DS(Dataset):
    """
    Chunked "mémoires" documents stored on the Hugging Face Hub as
    eliot-hub/memoires_vec_800, requiring the HuggingFaceEmbeddingFunction.
    """
    hf_data = None
    hf_dataset_name = "eliot-hub/memoires_vec_800"
    embedding_function = "HuggingFaceEmbeddingFunction"
    embedding_function_instructions = ef_instruction_dict[embedding_function]
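
# A minimal usage sketch (not part of the original file): `load_data` pulls the
# "data" split from the Hub, `raw_text` joins the "document" column, and
# `chunked` returns the datapoints as stored. Commented out so that importing
# this module stays side-effect free.
#
#   Memoires_DS.load_data()
#   print(Memoires_DS.raw_text()[:200])   # first 200 characters of the corpus
#   print(len(Memoires_DS.chunked()))     # number of pre-chunked datapoints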


def import_into_chroma(chroma_client, dataset, collection_name=None, embedding_function=None, batch_size=5000):
    """
    Imports a dataset into Chroma in batches.

    Args:
        chroma_client (ChromaClient): The Chroma client to use.
        dataset (Dataset): The dataset class to load.
        collection_name (str): The name of the collection to load the dataset into. Defaults to the dataset class name.
        embedding_function (Optional[EmbeddingFunction]): The embedding function to attach to the collection.
        batch_size (int): The number of records to add per batch.
    """
    # If chromadb is not installed, raise a helpful error
    try:
        import chromadb
        from chromadb.utils import embedding_functions
    except ImportError:
        raise ImportError("Please install chromadb to use this function. `pip install chromadb`")

    ef = None

    if dataset.embedding_function is not None:
        if embedding_function is None:
            error_msg = "See documentation"
            if dataset.embedding_function_instructions is not None:
                error_msg = dataset.embedding_function_instructions
            raise ValueError(f"""
            Dataset requires embedding function: {dataset.embedding_function}.
            {error_msg}
            """)

        if embedding_function.__class__.__name__ != dataset.embedding_function:
            raise ValueError(f"Please use {dataset.embedding_function} as the embedding function for this dataset. You passed {embedding_function.__class__.__name__}")

    if embedding_function is not None:
        ef = embedding_function

    # If no collection name was given, derive it from the dataset class name
    if collection_name is None:
        collection_name = dataset.__name__

    # Fall back to Chroma's default embedding function
    if ef is None:
        ef = embedding_functions.DefaultEmbeddingFunction()

    print("########### Init collection ###########")
    collection = chroma_client.create_collection(
        collection_name,
        embedding_function=ef
    )

    # Retrieve the mapped data
    print("########### Init to_chroma ###########")
    mapped_data = dataset.to_chroma()
    del dataset

    # Split the data into batches and add them to the collection
    def chunk_data(data, size):
        """Helper function to split data into batches."""
        for i in range(0, len(data), size):
            yield data[i:i + size]

    print("########### Chunking ###########")
    ids_batches = list(chunk_data(mapped_data["ids"], batch_size))
    metadatas_batches = list(chunk_data(mapped_data["metadatas"], batch_size))
    documents_batches = list(chunk_data(mapped_data["documents"], batch_size))
    embeddings_batches = list(chunk_data(mapped_data["embeddings"], batch_size))

    total_docs = len(mapped_data["ids"])

    print("########### Iterating batches ###########")
    for i, (ids, metadatas, documents, embeddings) in enumerate(zip(ids_batches, metadatas_batches, documents_batches, embeddings_batches)):
        collection.add(
            ids=ids,
            metadatas=metadatas,
            documents=documents,
            embeddings=embeddings,
        )
        print(f"Batch {i+1}/{len(ids_batches)}: Loaded {len(ids)} documents.")

    print(f"Successfully loaded {total_docs} documents into the collection named: {collection_name}")

    return collection
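

# A hedged end-to-end sketch (not part of the original file): build an
# in-memory Chroma client, construct the HuggingFaceEmbeddingFunction the
# dataset declares (the model name mirrors ef_instruction_dict above), and
# import the dataset in batches. The collection name "memoires" is an
# assumption; adjust as needed.
if __name__ == "__main__":
    import chromadb
    from chromadb.utils import embedding_functions

    client = chromadb.Client()  # in-memory; use chromadb.PersistentClient(path=...) to persist
    hf_ef = embedding_functions.HuggingFaceEmbeddingFunction(
        api_key=HF_API_KEY,
        model_name="mixedbread-ai/mxbai-embed-large-v1",
    )
    collection = import_into_chroma(
        chroma_client=client,
        dataset=Memoires_DS,
        collection_name="memoires",  # hypothetical name
        embedding_function=hf_ef,
        batch_size=5000,
    )
    print(f"Collection contains {collection.count()} documents.")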