rag_lite / src /raglite /_embed.py
EL GHAFRAOUI AYOUB
C
54f5afe
"""String embedder."""
from functools import partial
from typing import Literal
import numpy as np
from litellm import embedding
from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama
from tqdm.auto import tqdm, trange
from raglite._config import RAGLiteConfig
from raglite._litellm import LlamaCppPythonLLM
from raglite._typing import FloatMatrix, IntVector
def _embed_sentences_with_late_chunking(  # noqa: PLR0915
    sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
    """Embed a document's sentences with late chunking.

    The sentences are grouped into segments, each consisting of a preamble (preceding sentences
    that only provide context) followed by content sentences. Every segment is embedded at the
    token level, and the token embeddings are then averaged per content sentence.

    Parameters
    ----------
    sentences
        The document's sentences. Segments are formed by joining consecutive sentences with the
        empty string, so each sentence is expected to carry its own trailing whitespace.
    config
        The RAGLite config. Falls back to `RAGLiteConfig()` when omitted. Its `embedder` must
        start with "llama-cpp-python".

    Returns
    -------
    FloatMatrix
        A half precision matrix with one row per sentence. Rows are scaled to unit norm when
        `config.embedder_normalize` is set. An empty input yields a matrix with zero rows.

    Raises
    ------
    AssertionError
        If the embedder is not a llama-cpp-python model, if the embedder's tokenizer produces no
        token containing the sentinel character, or if the sentinel-based token count does not
        yield exactly one count per sentence (e.g. the sentinel appears in the document).
    """

    def _count_tokens(
        sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int]
    ) -> list[int]:
        # Join the sentences with the sentinel token and tokenise the result.
        sentences_tokens = np.asarray(
            embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp
        )
        # Map all sentinel token variants to the first one.
        for sentinel_token in sentinel_tokens[1:]:
            sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0]
        # Count how many tokens there are in between sentinel tokens to recover the token counts.
        # Note that the counts of all but the first sentence include the sentinel token itself,
        # which is harmless because the counts are only used as budgets and relative weights.
        sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0]
        num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens))
        assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document"
        num_tokens_list: list[int] = num_tokens.tolist()
        return num_tokens_list

    def _create_segment(
        content_start_index: int,
        max_tokens_preamble: int,
        max_tokens_content: int,
        num_tokens: IntVector,
    ) -> tuple[int, int]:
        # Compute the segment sentence start index so that the segment preamble has no more than
        # max_tokens_preamble tokens between [segment_start_index, content_start_index).
        cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1])
        offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right")
        segment_start_index = content_start_index - int(offset_preamble)
        # Allow a larger segment content if we didn't use all of the allowed preamble tokens.
        max_tokens_content = max_tokens_content + (
            max_tokens_preamble - np.sum(num_tokens[segment_start_index:content_start_index])
        )
        # Compute the segment sentence end index so that the segment content has no more than
        # max_tokens_content tokens between [content_start_index, segment_end_index).
        cumsum_forwards = np.cumsum(num_tokens[content_start_index:])
        offset_segment = np.searchsorted(cumsum_forwards, max_tokens_content, side="right")
        segment_end_index = content_start_index + int(offset_segment)
        # Always include at least one content sentence, even if that sentence on its own exceeds
        # the content budget. Without this, the caller's segmentation loop would never advance
        # past an oversized sentence. The per-sentence split further below is proportional to
        # the number of token embeddings actually returned, so it still works on such a segment.
        segment_end_index = max(segment_end_index, content_start_index + 1)
        return segment_start_index, segment_end_index

    # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
    # support outputting token-level embeddings.
    config = config or RAGLiteConfig()
    assert config.embedder.startswith("llama-cpp-python")
    embedder = LlamaCppPythonLLM.llm(
        config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
    )
    # An empty document has no segments to embed: return a matrix with zero rows.
    if not sentences:
        return np.zeros((0, embedder.n_embd()), dtype=np.float16)
    n_ctx = embedder.n_ctx()
    n_batch = embedder.n_batch
    # Identify the tokens corresponding to a sentinel character. The sentinel is the single
    # character U+2295 (circled plus), written as an escape so that it survives re-encoding of
    # this file.
    sentinel_char = "\u2295"
    sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D"
    sentinel_tokens = [
        token
        for token in embedder.tokenize(sentinel_test.encode(), add_bos=False)
        if sentinel_char in embedder.detokenize([token]).decode()
    ]
    assert len(sentinel_tokens), f"Sentinel `{sentinel_char}` not supported by embedder"
    # Compute the number of tokens per sentence. We use a method based on a sentinel token to
    # minimise the number of calls to embedder.tokenize, which incurs a significant overhead
    # (presumably to load the tokenizer) [1].
    # TODO: Make token counting faster and more robust once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1763
    num_tokens_list: list[int] = []
    sentence_batch, sentence_batch_len = [], 0
    for i, sentence in enumerate(sentences):
        sentence_batch.append(sentence)
        sentence_batch_len += len(sentence)
        # Flush the batch once its character count exceeds half the context size, or at the end.
        if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2):
            num_tokens_list.extend(
                _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens)
            )
            sentence_batch, sentence_batch_len = [], 0
    num_tokens = np.asarray(num_tokens_list, dtype=np.intp)
    # Compute the maximum number of tokens for each segment's preamble and content.
    # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try
    # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch.
    # TODO: Improve the context window size once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1762
    max_tokens = min(n_ctx, n_batch) - 16
    max_tokens_preamble = round(0.382 * max_tokens)  # Golden ratio.
    max_tokens_content = max_tokens - max_tokens_preamble
    # Compute a list of segments, each consisting of a preamble and content.
    segments = []
    content_start_index = 0
    while content_start_index < len(sentences):
        segment_start_index, segment_end_index = _create_segment(
            content_start_index, max_tokens_preamble, max_tokens_content, num_tokens
        )
        segments.append((segment_start_index, content_start_index, segment_end_index))
        content_start_index = segment_end_index
    # Embed the segments and apply late chunking.
    sentence_embeddings_list: list[FloatMatrix] = []
    # Only show a progress bar when there is more than one segment or many sentences.
    if len(segments) > 1 or segments[0][2] > 128:  # noqa: PLR2004
        segments = tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True)
    for segment in segments:
        # Get the token embeddings of the entire segment, including preamble and content.
        segment_start_index, content_start_index, segment_end_index = segment
        segment_sentences = sentences[segment_start_index:segment_end_index]
        segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
        # Split the segment embeddings into embedding matrices per sentence. The split is
        # proportional to the sentence token counts rather than exact, because the number of
        # token embeddings returned need not equal the sum of the per-sentence counts.
        segment_tokens = num_tokens[segment_start_index:segment_end_index]
        sentence_size = np.round(
            len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
        ).astype(np.intp)
        sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1])
        # Compute the segment sentence embeddings by averaging the token embeddings. The preamble
        # sentences are skipped: they only provide context for the content sentences.
        content_sentence_embeddings = [
            np.mean(sentence_matrix, axis=0, keepdims=True)
            for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :]
        ]
        sentence_embeddings_list.append(np.vstack(content_sentence_embeddings))
    sentence_embeddings = np.vstack(sentence_embeddings_list)
    # Normalise the sentence embeddings to unit norm and cast to half precision.
    if config.embedder_normalize:
        sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
    sentence_embeddings = sentence_embeddings.astype(np.float16)
    return sentence_embeddings
def _embed_sentences_with_windowing(
    sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
    """Embed a document's sentences with windowing.

    Each sentence is embedded together with up to `config.embedder_sentence_window_size - 1`
    sentences that precede it. The resulting windows are embedded in batches of 64 strings and
    the batch results are stacked into a single matrix with one row per sentence.
    """

    def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
        # Pick the embedding backend based on the configured embedder.
        if not config.embedder.startswith("llama-cpp-python"):
            # Any other embedder is served through LiteLLM's embedding API.
            response = embedding(config.embedder, string_batch)
            embeddings = np.asarray([item["embedding"] for item in response["data"]])
        else:
            # LiteLLM can't register a custom embedder yet, so llama-cpp-python is called
            # directly. Pooling is disabled and the token embeddings are averaged by hand, since
            # token embeddings are supported by all models while sequence embeddings are not.
            embedder = LlamaCppPythonLLM.llm(
                config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
            )
            token_embeddings = embedder.embed(string_batch)
            embeddings = np.asarray([np.mean(row, axis=0) for row in token_embeddings])
        # Optionally scale each row to unit norm, then cast to half precision.
        if config.embedder_normalize:
            embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings.astype(np.float16)

    config = config or RAGLiteConfig()
    # Build one window per sentence: the sentence itself preceded by its lookback sentences.
    sentence_windows = []
    for end in range(1, len(sentences) + 1):
        start = max(0, end - config.embedder_sentence_window_size)
        sentence_windows.append("".join(sentences[start:end]))
    # Iterate over the batch start offsets, with a progress bar if there's more than one batch.
    batch_size = 64
    num_windows = len(sentence_windows)
    if num_windows > batch_size:
        batch_starts = trange(
            0, num_windows, batch_size, desc="Embedding", unit="batch", dynamic_ncols=True
        )
    else:
        batch_starts = range(0, num_windows, batch_size)
    # Embed the sentence windows batch by batch and stack the results.
    batch_embeddings = []
    for offset in batch_starts:
        window_batch = sentence_windows[offset : offset + batch_size]
        batch_embeddings.append(_embed_string_batch(window_batch, config=config))
    return np.vstack(batch_embeddings)
def sentence_embedding_type(
    *,
    config: RAGLiteConfig | None = None,
) -> Literal["late_chunking", "windowing"]:
    """Return which sentence embedding strategy applies to the configured embedder.

    The result is "late_chunking" when `config.embedder` starts with "llama-cpp-python", and
    "windowing" otherwise. A default `RAGLiteConfig()` is used when no config is given.
    """
    config = config or RAGLiteConfig()
    if config.embedder.startswith("llama-cpp-python"):
        return "late_chunking"
    return "windowing"
def embed_sentences(sentences: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
    """Embed the sentences of a document as a NumPy matrix with one row per sentence.

    Late chunking is used when `sentence_embedding_type` reports "late_chunking" for the config,
    and windowing is used otherwise. A default `RAGLiteConfig()` is used when no config is given.
    """
    config = config or RAGLiteConfig()
    # Select the embedding strategy first, then apply it to the sentences.
    embed = (
        _embed_sentences_with_late_chunking
        if sentence_embedding_type(config=config) == "late_chunking"
        else _embed_sentences_with_windowing
    )
    return embed(sentences, config=config)