"""String embedder."""
from functools import partial
from typing import Literal
import numpy as np
from litellm import embedding
from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama
from tqdm.auto import tqdm, trange
from raglite._config import RAGLiteConfig
from raglite._litellm import LlamaCppPythonLLM
from raglite._typing import FloatMatrix, IntVector


def _embed_sentences_with_late_chunking(  # noqa: PLR0915
sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
"""Embed a document's sentences with late chunking."""

    def _count_tokens(
sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int]
) -> list[int]:
# Join the sentences with the sentinel token and tokenise the result.
sentences_tokens = np.asarray(
embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp
)
# Map all sentinel token variants to the first one.
for sentinel_token in sentinel_tokens[1:]:
sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0]
# Count how many tokens there are in between sentinel tokens to recover the token counts.
sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0]
num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens))
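        # Illustrative example (hypothetical token IDs): joining ["A", "BB", "C"] with the
        # sentinel might tokenise to [a, s, b, b, s, c], so sentinel_indices == [1, 4] and
        # np.diff([1, 4], prepend=0, append=6) == [1, 3, 2]. Every count after the first
        # includes the joining sentinel itself, a small overcount that errs on the safe side.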
assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document"
num_tokens_list: list[int] = num_tokens.tolist()
return num_tokens_list

    def _create_segment(
content_start_index: int,
max_tokens_preamble: int,
max_tokens_content: int,
num_tokens: IntVector,
) -> tuple[int, int]:
# Compute the segment sentence start index so that the segment preamble has no more than
# max_tokens_preamble tokens between [segment_start_index, content_start_index).
cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1])
offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right")
segment_start_index = content_start_index - int(offset_preamble)
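        # Illustrative example (hypothetical counts): with num_tokens[:content_start_index]
        # == [4, 6] and max_tokens_preamble == 8, cumsum_backwards == [6, 10], searchsorted
        # returns 1, and the preamble reaches exactly one sentence back (6 <= 8 < 10).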
        # Allow a larger content budget if the preamble didn't use all of its allowed tokens.
max_tokens_content = max_tokens_content + (
max_tokens_preamble - np.sum(num_tokens[segment_start_index:content_start_index])
)
# Compute the segment sentence end index so that the segment content has no more than
# max_tokens_content tokens between [content_start_index, segment_end_index).
cumsum_forwards = np.cumsum(num_tokens[content_start_index:])
offset_segment = np.searchsorted(cumsum_forwards, max_tokens_content, side="right")
segment_end_index = content_start_index + int(offset_segment)
return segment_start_index, segment_end_index

    # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
# support outputting token-level embeddings.
config = config or RAGLiteConfig()
assert config.embedder.startswith("llama-cpp-python")
embedder = LlamaCppPythonLLM.llm(
config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
)
n_ctx = embedder.n_ctx()
n_batch = embedder.n_batch
# Identify the tokens corresponding to a sentinel character.
sentinel_char = "⊕"
sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D"
sentinel_tokens = [
token
for token in embedder.tokenize(sentinel_test.encode(), add_bos=False)
if sentinel_char in embedder.detokenize([token]).decode()
]
assert len(sentinel_tokens), f"Sentinel `{sentinel_char}` not supported by embedder"
    # Compute the number of tokens per sentence. We use a sentinel-based method to minimise the
    # number of calls to embedder.tokenize, each of which incurs a significant overhead
    # (presumably to load the tokenizer) [1].
# TODO: Make token counting faster and more robust once [1] is fixed.
# [1] https://github.com/abetlen/llama-cpp-python/issues/1763
num_tokens_list: list[int] = []
sentence_batch, sentence_batch_len = [], 0
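    # Batch sentences by character count, a cheap rough proxy for token count, so that each
    # batched tokenize call stays well within the context window.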
for i, sentence in enumerate(sentences):
sentence_batch.append(sentence)
sentence_batch_len += len(sentence)
if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2):
num_tokens_list.extend(
_count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens)
)
sentence_batch, sentence_batch_len = [], 0
num_tokens = np.asarray(num_tokens_list, dtype=np.intp)
# Compute the maximum number of tokens for each segment's preamble and content.
# Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try
# to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch.
# TODO: Improve the context window size once [1] is fixed.
# [1] https://github.com/abetlen/llama-cpp-python/issues/1762
max_tokens = min(n_ctx, n_batch) - 16
    max_tokens_preamble = round(0.382 * max_tokens)  # Golden ratio split (0.382 ≈ 1 / φ²).
max_tokens_content = max_tokens - max_tokens_preamble
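    # For example, with n_ctx == n_batch == 512 this gives max_tokens == 496, of which
    # round(0.382 * 496) == 189 tokens go to the preamble and the remaining 307 to the
    # content (the 16-token subtraction leaves a safety margin).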
# Compute a list of segments, each consisting of a preamble and content.
segments = []
content_start_index = 0
while content_start_index < len(sentences):
segment_start_index, segment_end_index = _create_segment(
content_start_index, max_tokens_preamble, max_tokens_content, num_tokens
)
segments.append((segment_start_index, content_start_index, segment_end_index))
content_start_index = segment_end_index
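    # The segment triples tile the document: the content ranges are contiguous and disjoint, so
    # every sentence is embedded exactly once as content, while each preamble overlaps earlier
    # segments to supply context.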
# Embed the segments and apply late chunking.
sentence_embeddings_list: list[FloatMatrix] = []
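    # Only show a progress bar for non-trivial workloads.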
if len(segments) > 1 or segments[0][2] > 128: # noqa: PLR2004
segments = tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True)
for segment in segments:
# Get the token embeddings of the entire segment, including preamble and content.
segment_start_index, content_start_index, segment_end_index = segment
segment_sentences = sentences[segment_start_index:segment_end_index]
segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
# Split the segment embeddings into embedding matrices per sentence.
segment_tokens = num_tokens[segment_start_index:segment_end_index]
sentence_size = np.round(
len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
).astype(np.intp)
sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1])
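            # Because sentence_size is rounded, the split points are approximate; np.split
            # tolerates this, and any rounding slack is absorbed by neighbouring sentences.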
# Compute the segment sentence embeddings by averaging the token embeddings.
content_sentence_embeddings = [
np.mean(sentence_matrix, axis=0, keepdims=True)
for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :]
]
sentence_embeddings_list.append(np.vstack(content_sentence_embeddings))
sentence_embeddings = np.vstack(sentence_embeddings_list)
# Normalise the sentence embeddings to unit norm and cast to half precision.
if config.embedder_normalize:
sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
sentence_embeddings = sentence_embeddings.astype(np.float16)
return sentence_embeddings


def _embed_sentences_with_windowing(
sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
"""Embed a document's sentences with windowing."""

    def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
# Embed the batch of strings.
if config.embedder.startswith("llama-cpp-python"):
            # LiteLLM doesn't yet support registering a custom embedder, so we handle it here.
            # Additionally, we manually pool the token embeddings into sentence embeddings
            # because token embeddings are universally supported, while sequence embeddings are
            # only supported by some models.
embedder = LlamaCppPythonLLM.llm(
config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
)
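            # With LLAMA_POOLING_TYPE_NONE, embed() returns one embedding per token for each
            # input string, so averaging over the token axis yields one vector per string.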
embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
else:
# Use LiteLLM's API to embed the batch of strings.
response = embedding(config.embedder, string_batch)
embeddings = np.asarray([item["embedding"] for item in response["data"]])
# Normalise the embeddings to unit norm and cast to half precision.
if config.embedder_normalize:
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings.astype(np.float16)
return embeddings

    # Window the sentences with a lookback of `config.embedder_sentence_window_size - 1` sentences.
config = config or RAGLiteConfig()
sentence_windows = [
"".join(sentences[max(0, i - (config.embedder_sentence_window_size - 1)) : i + 1])
for i in range(len(sentences))
]
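    # For example, with embedder_sentence_window_size == 3, window i spans
    # sentences[max(0, i - 2) : i + 1]: each sentence plus up to two preceding sentences.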
# Embed the sentence windows in batches.
batch_size = 64
batch_range = (
partial(trange, desc="Embedding", unit="batch", dynamic_ncols=True)
if len(sentence_windows) > batch_size
else range
)
batch_embeddings = [
_embed_string_batch(sentence_windows[i : i + batch_size], config=config)
for i in batch_range(0, len(sentence_windows), batch_size) # type: ignore[operator]
]
sentence_embeddings = np.vstack(batch_embeddings)
return sentence_embeddings


def sentence_embedding_type(
*,
config: RAGLiteConfig | None = None,
) -> Literal["late_chunking", "windowing"]:
"""Return the type of sentence embeddings."""
config = config or RAGLiteConfig()
return "late_chunking" if config.embedder.startswith("llama-cpp-python") else "windowing"


def embed_sentences(sentences: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
"""Embed the sentences of a document as a NumPy matrix with one row per sentence."""
config = config or RAGLiteConfig()
if sentence_embedding_type(config=config) == "late_chunking":
sentence_embeddings = _embed_sentences_with_late_chunking(sentences, config=config)
else:
sentence_embeddings = _embed_sentences_with_windowing(sentences, config=config)
return sentence_embeddings
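

# Minimal usage sketch (illustrative: the embedder identifier below is a hypothetical
# placeholder, not a tested configuration):
#
#     from raglite._config import RAGLiteConfig
#     from raglite._embed import embed_sentences
#
#     config = RAGLiteConfig(embedder="llama-cpp-python/<hf-repo>/<filename>.gguf")
#     sentences = ["RAGLite embeds sentences.", "Each sentence becomes one matrix row."]
#     embeddings = embed_sentences(sentences, config=config)
#     assert embeddings.shape[0] == len(sentences)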