Spaces:
Build error
Build error
File size: 4,399 Bytes
d660b02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from loguru import logger
from llm_engineering.domain.base import NoSQLBaseDocument, VectorBaseDocument
from llm_engineering.domain.types import DataCategory
from .chunking_data_handlers import (
ArticleChunkingHandler,
ChunkingDataHandler,
PostChunkingHandler,
RepositoryChunkingHandler,
)
from .cleaning_data_handlers import (
ArticleCleaningHandler,
CleaningDataHandler,
PostCleaningHandler,
RepositoryCleaningHandler,
)
from .embedding_data_handlers import (
ArticleEmbeddingHandler,
EmbeddingDataHandler,
PostEmbeddingHandler,
QueryEmbeddingHandler,
RepositoryEmbeddingHandler,
)
class CleaningHandlerFactory:
    """Factory that maps a data category to its cleaning handler."""

    @staticmethod
    def create_handler(data_category: DataCategory) -> CleaningDataHandler:
        """Instantiate the cleaning handler for ``data_category``.

        Args:
            data_category: Category of the document that needs cleaning.

        Returns:
            A new ``PostCleaningHandler``, ``ArticleCleaningHandler`` or
            ``RepositoryCleaningHandler``, depending on the category.

        Raises:
            ValueError: If the category is not posts, articles or
                repositories.
        """
        # Ordered (category, handler class) pairs; the first match wins.
        registry = (
            (DataCategory.POSTS, PostCleaningHandler),
            (DataCategory.ARTICLES, ArticleCleaningHandler),
            (DataCategory.REPOSITORIES, RepositoryCleaningHandler),
        )
        for category, handler_cls in registry:
            if data_category == category:
                return handler_cls()

        raise ValueError("Unsupported data type")
class CleaningDispatcher:
    """Entry point that cleans a raw document with the matching handler."""

    # Factory instance used to look up the handler for each document.
    cleaning_factory = CleaningHandlerFactory()

    @classmethod
    def dispatch(cls, data_model: NoSQLBaseDocument) -> VectorBaseDocument:
        """Clean ``data_model`` and return the cleaned document.

        The category is built from the document's collection name, the
        factory supplies the handler for that category, and the handler's
        ``clean`` method produces the result.

        Args:
            data_model: Raw document to clean.

        Returns:
            The document returned by the handler's ``clean`` method.

        Raises:
            ValueError: Propagated from the factory when the category has
                no cleaning handler.
        """
        category = DataCategory(data_model.get_collection_name())
        cleaner = cls.cleaning_factory.create_handler(category)
        cleaned = cleaner.clean(data_model)

        # Extra fields are handed to the logger as keyword arguments.
        log_fields = {
            "data_category": category,
            "cleaned_content_len": len(cleaned.content),
        }
        logger.info("Document cleaned successfully.", **log_fields)

        return cleaned
class ChunkingHandlerFactory:
    """Factory that maps a data category to its chunking handler."""

    @staticmethod
    def create_handler(data_category: DataCategory) -> ChunkingDataHandler:
        """Instantiate the chunking handler for ``data_category``.

        Args:
            data_category: Category of the document that needs chunking.

        Returns:
            A new ``PostChunkingHandler``, ``ArticleChunkingHandler`` or
            ``RepositoryChunkingHandler``, depending on the category.

        Raises:
            ValueError: If the category is not posts, articles or
                repositories.
        """
        # Ordered (category, handler class) pairs; the first match wins.
        registry = (
            (DataCategory.POSTS, PostChunkingHandler),
            (DataCategory.ARTICLES, ArticleChunkingHandler),
            (DataCategory.REPOSITORIES, RepositoryChunkingHandler),
        )
        for category, handler_cls in registry:
            if data_category == category:
                return handler_cls()

        raise ValueError("Unsupported data type")
class ChunkingDispatcher:
    """Entry point that splits a cleaned document into chunks."""

    # NOTE(review): despite its name, this attribute holds the *chunking*
    # factory. The name is left unchanged because it is a public class
    # attribute that other code may reference.
    cleaning_factory = ChunkingHandlerFactory

    @classmethod
    def dispatch(cls, data_model: VectorBaseDocument) -> list[VectorBaseDocument]:
        """Chunk ``data_model`` with the handler for its category.

        Args:
            data_model: Document to chunk; its ``get_category()`` result
                selects the handler.

        Returns:
            The chunk documents returned by the handler's ``chunk`` method.

        Raises:
            ValueError: Propagated from the factory when the category has
                no chunking handler.
        """
        category = data_model.get_category()
        chunker = cls.cleaning_factory.create_handler(category)
        chunks = chunker.chunk(data_model)

        # Extra fields are handed to the logger as keyword arguments.
        log_fields = {"num": len(chunks), "data_category": category}
        logger.info("Document chunked successfully.", **log_fields)

        return chunks
class EmbeddingHandlerFactory:
    """Factory that maps a data category to its embedding handler."""

    @staticmethod
    def create_handler(data_category: DataCategory) -> EmbeddingDataHandler:
        """Instantiate the embedding handler for ``data_category``.

        Args:
            data_category: Category of the data that needs embedding.

        Returns:
            A new ``QueryEmbeddingHandler``, ``PostEmbeddingHandler``,
            ``ArticleEmbeddingHandler`` or ``RepositoryEmbeddingHandler``,
            depending on the category.

        Raises:
            ValueError: If the category is not queries, posts, articles or
                repositories.
        """
        # Ordered (category, handler class) pairs; the first match wins.
        registry = (
            (DataCategory.QUERIES, QueryEmbeddingHandler),
            (DataCategory.POSTS, PostEmbeddingHandler),
            (DataCategory.ARTICLES, ArticleEmbeddingHandler),
            (DataCategory.REPOSITORIES, RepositoryEmbeddingHandler),
        )
        for category, handler_cls in registry:
            if data_category == category:
                return handler_cls()

        raise ValueError("Unsupported data type")
class EmbeddingDispatcher:
    """Entry point that embeds one document or a batch of documents."""

    # NOTE(review): despite its name, this attribute holds the *embedding*
    # factory. The name is left unchanged because it is a public class
    # attribute that other code may reference.
    cleaning_factory = EmbeddingHandlerFactory

    @classmethod
    def dispatch(
        cls, data_model: VectorBaseDocument | list[VectorBaseDocument]
    ) -> VectorBaseDocument | list[VectorBaseDocument]:
        """Embed ``data_model`` with the handler for its category.

        A single document is wrapped in a one-element batch, embedded, and
        unwrapped again, so the return shape mirrors the input shape.

        Args:
            data_model: One document, or a list of documents that all share
                the same category.

        Returns:
            ``[]`` for an empty list; otherwise the result of the handler's
            ``embed_batch`` call -- the whole result for list input, or its
            first element for single-document input.

        Raises:
            AssertionError: If the documents do not all report the same
                category.
            ValueError: Propagated from the factory when the category has
                no embedding handler.
        """
        is_list = isinstance(data_model, list)
        documents = data_model if is_list else [data_model]

        if not documents:
            return []

        data_category = documents[0].get_category()

        # Raised explicitly instead of via an ``assert`` statement so the
        # check still runs under ``python -O``. The exception type and
        # message are unchanged for callers.
        if not all(doc.get_category() == data_category for doc in documents):
            raise AssertionError("Data models must be of the same category.")

        handler = cls.cleaning_factory.create_handler(data_category)
        embedded_chunk_model = handler.embed_batch(documents)

        if not is_list:
            # Single-document input: unwrap the one-element batch result.
            embedded_chunk_model = embedded_chunk_model[0]

        logger.info(
            "Data embedded successfully.",
            data_category=data_category,
        )

        return embedded_chunk_model
|