SkazuHD's picture
init space
d660b02
from typing_extensions import Annotated
from clearml import PipelineDecorator
from llm_engineering.application import utils
from llm_engineering.application.preprocessing import ChunkingDispatcher, EmbeddingDispatcher
from llm_engineering.domain.chunks import Chunk
from llm_engineering.domain.embedded_chunks import EmbeddedChunk
@PipelineDecorator.component(name="chunk_and_embed")
def chunk_and_embed(
cleaned_documents: Annotated[list, "cleaned_documents"],
) -> Annotated[list, "embedded_documents"]:
def _add_chunks_metadata(chunks: list[Chunk], metadata: dict) -> dict:
for chunk in chunks:
category = chunk.get_category()
if category not in metadata:
metadata[category] = chunk.metadata
if "authors" not in metadata[category]:
metadata[category]["authors"] = list()
metadata[category]["num_chunks"] = metadata[category].get("num_chunks", 0) + 1
metadata[category]["authors"].append(chunk.author_full_name)
for value in metadata.values():
if isinstance(value, dict) and "authors" in value:
value["authors"] = list(set(value["authors"]))
return metadata
def _add_embeddings_metadata(embedded_chunks: list[EmbeddedChunk], metadata: dict) -> dict:
for embedded_chunk in embedded_chunks:
category = embedded_chunk.get_category()
if category not in metadata:
metadata[category] = embedded_chunk.metadata
if "authors" not in metadata[category]:
metadata[category]["authors"] = list()
metadata[category]["authors"].append(embedded_chunk.author_full_name)
for value in metadata.values():
if isinstance(value, dict) and "authors" in value:
value["authors"] = list(set(value["authors"]))
return metadata
metadata = {"chunking": {}, "embedding": {}, "num_documents": len(cleaned_documents)}
embedded_chunks = []
for document in cleaned_documents:
chunks = ChunkingDispatcher.dispatch(document)
metadata["chunking"] = _add_chunks_metadata(chunks, metadata["chunking"])
for batched_chunks in utils.misc.batch(chunks, 10):
batched_embedded_chunks = EmbeddingDispatcher.dispatch(batched_chunks)
embedded_chunks.extend(batched_embedded_chunks)
metadata["embedding"] = _add_embeddings_metadata(embedded_chunks, metadata["embedding"])
metadata["num_chunks"] = len(embedded_chunks)
metadata["num_embedded_chunks"] = len(embedded_chunks)
#step_context = get_step_context()
#step_context.add_output_metadata(output_name="embedded_documents", metadata=metadata)
return embedded_chunks