# Spaces:
# Build error
# Build error
from loguru import logger | |
from llm_engineering.domain.base import NoSQLBaseDocument, VectorBaseDocument | |
from llm_engineering.domain.types import DataCategory | |
from .chunking_data_handlers import ( | |
ArticleChunkingHandler, | |
ChunkingDataHandler, | |
PostChunkingHandler, | |
RepositoryChunkingHandler, | |
) | |
from .cleaning_data_handlers import ( | |
ArticleCleaningHandler, | |
CleaningDataHandler, | |
PostCleaningHandler, | |
RepositoryCleaningHandler, | |
) | |
from .embedding_data_handlers import ( | |
ArticleEmbeddingHandler, | |
EmbeddingDataHandler, | |
PostEmbeddingHandler, | |
QueryEmbeddingHandler, | |
RepositoryEmbeddingHandler, | |
) | |
class CleaningHandlerFactory:
    """Select the cleaning handler that matches a data category."""

    @staticmethod
    def create_handler(data_category: DataCategory) -> CleaningDataHandler:
        """Return a new cleaning handler for ``data_category``.

        Declared static because the function takes no ``self``/``cls``;
        without the decorator, calling it through a factory *instance*
        would bind that instance to ``data_category`` and fail.

        Args:
            data_category: Category of the document to be cleaned.

        Returns:
            A freshly constructed handler for posts, articles or
            repositories.

        Raises:
            ValueError: If ``data_category`` is none of the categories above.
        """
        if data_category == DataCategory.POSTS:
            return PostCleaningHandler()
        elif data_category == DataCategory.ARTICLES:
            return ArticleCleaningHandler()
        elif data_category == DataCategory.REPOSITORIES:
            return RepositoryCleaningHandler()
        else:
            raise ValueError("Unsupported data type")
class CleaningDispatcher:
    """Route a raw document to the cleaning handler for its category."""

    # Hold the factory class itself (as the other dispatchers in this module
    # do): ``create_handler`` needs no instance state, and calling it through
    # the class works whether or not it is declared as a static method.
    cleaning_factory = CleaningHandlerFactory

    @classmethod
    def dispatch(cls, data_model: NoSQLBaseDocument) -> VectorBaseDocument:
        """Clean ``data_model`` with the handler chosen for its category.

        The category is built from the document's collection name via
        ``DataCategory(data_model.get_collection_name())``.

        Args:
            data_model: The raw document to clean.

        Returns:
            The cleaned document produced by the handler's ``clean`` method.

        Raises:
            ValueError: If the factory has no handler for the category.
        """
        data_category = DataCategory(data_model.get_collection_name())
        handler = cls.cleaning_factory.create_handler(data_category)
        clean_model = handler.clean(data_model)

        logger.info(
            "Document cleaned successfully.",
            data_category=data_category,
            cleaned_content_len=len(clean_model.content),
        )

        return clean_model
class ChunkingHandlerFactory:
    """Select the chunking handler that matches a data category."""

    @staticmethod
    def create_handler(data_category: DataCategory) -> ChunkingDataHandler:
        """Return a new chunking handler for ``data_category``.

        Declared static because the function takes no ``self``/``cls``, so it
        can be called through either the class or an instance.

        Args:
            data_category: Category of the document to be chunked.

        Returns:
            A freshly constructed handler for posts, articles or
            repositories.

        Raises:
            ValueError: If ``data_category`` is none of the categories above.
        """
        if data_category == DataCategory.POSTS:
            return PostChunkingHandler()
        elif data_category == DataCategory.ARTICLES:
            return ArticleChunkingHandler()
        elif data_category == DataCategory.REPOSITORIES:
            return RepositoryChunkingHandler()
        else:
            raise ValueError("Unsupported data type")
class ChunkingDispatcher:
    """Route a cleaned document to the chunking handler for its category."""

    # NOTE: despite its name this attribute holds the *chunking* factory.
    # The name is kept as-is because it is a public class attribute.
    cleaning_factory = ChunkingHandlerFactory

    @classmethod
    def dispatch(cls, data_model: VectorBaseDocument) -> list[VectorBaseDocument]:
        """Split ``data_model`` into chunks with the handler for its category.

        The category is read from ``data_model.get_category()``.

        Args:
            data_model: The document to chunk.

        Returns:
            The chunk documents produced by the handler's ``chunk`` method.

        Raises:
            ValueError: If the factory has no handler for the category.
        """
        data_category = data_model.get_category()
        handler = cls.cleaning_factory.create_handler(data_category)
        chunk_models = handler.chunk(data_model)

        logger.info(
            "Document chunked successfully.",
            num=len(chunk_models),
            data_category=data_category,
        )

        return chunk_models
class EmbeddingHandlerFactory:
    """Select the embedding handler that matches a data category."""

    @staticmethod
    def create_handler(data_category: DataCategory) -> EmbeddingDataHandler:
        """Return a new embedding handler for ``data_category``.

        Declared static because the function takes no ``self``/``cls``, so it
        can be called through either the class or an instance.

        Args:
            data_category: Category of the data to be embedded.

        Returns:
            A freshly constructed handler for queries, posts, articles or
            repositories.

        Raises:
            ValueError: If ``data_category`` is none of the categories above.
        """
        if data_category == DataCategory.QUERIES:
            return QueryEmbeddingHandler()
        elif data_category == DataCategory.POSTS:
            return PostEmbeddingHandler()
        elif data_category == DataCategory.ARTICLES:
            return ArticleEmbeddingHandler()
        elif data_category == DataCategory.REPOSITORIES:
            return RepositoryEmbeddingHandler()
        else:
            raise ValueError("Unsupported data type")
class EmbeddingDispatcher:
    """Route one document, or a batch, to the matching embedding handler."""

    # NOTE: despite its name this attribute holds the *embedding* factory.
    # The name is kept as-is because it is a public class attribute.
    cleaning_factory = EmbeddingHandlerFactory

    @classmethod
    def dispatch(
        cls, data_model: VectorBaseDocument | list[VectorBaseDocument]
    ) -> VectorBaseDocument | list[VectorBaseDocument]:
        """Embed ``data_model`` with the handler chosen for its category.

        A single document is wrapped in a one-element list, embedded through
        the handler's ``embed_batch`` method and unwrapped again, so the
        return shape mirrors the input shape.

        Args:
            data_model: One document or a list of documents. All documents in
                a list must report the same ``get_category()``.

        Returns:
            The embedded document for a single input, a list of embedded
            documents for a list input, or ``[]`` for an empty list (in which
            case nothing is logged).

        Raises:
            AssertionError: If the documents do not share one category.
            ValueError: If the factory has no handler for the category.
        """
        is_list = isinstance(data_model, list)
        if not is_list:
            data_model = [data_model]

        if not data_model:
            return []

        # The first document decides which handler embeds the whole batch.
        data_category = data_model[0].get_category()
        # NOTE: kept as an ``assert`` so the exception type seen by callers is
        # unchanged; be aware the check is skipped when Python runs with -O.
        assert all(
            model.get_category() == data_category for model in data_model
        ), "Data models must be of the same category."

        handler = cls.cleaning_factory.create_handler(data_category)
        embedded_chunk_model = handler.embed_batch(data_model)

        if not is_list:
            embedded_chunk_model = embedded_chunk_model[0]

        logger.info(
            "Data embedded successfully.",
            data_category=data_category,
        )

        return embedded_chunk_model