"""Compute and update an optimal query adapter."""
import numpy as np
from sqlmodel import Session, col, select
from tqdm.auto import tqdm
from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
from raglite._embed import embed_sentences
from raglite._search import vector_search
def update_query_adapter( # noqa: PLR0915, C901
*,
max_triplets: int = 4096,
max_triplets_per_eval: int = 64,
optimize_top_k: int = 40,
config: RAGLiteConfig | None = None,
) -> None:
"""Compute an optimal query adapter and update the database with it.
This function computes an optimal linear transform A, called a 'query adapter', that is used to
transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in
order to improve the quality of the search results.
Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.
If the nearest neighbour search uses the dot product as its relevance score, we can find the
optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
bound on the Frobenius norm of A:

    A* := argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
          A
        = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
          A
        = argmax trace[ (P - N) A Q' ]  where  Q := [q₁'; ...; qₖ']
          A                                    P := [p₁'; ...; pₖ']
                                               N := [n₁'; ...; nₖ']
        = argmax trace[ Q' (P - N) A ]
          A
        = argmax trace[ M A ]  where  M := Q' (P - N)
          A
          s.t. ||A||_F == 1
        = M' / ||M||_F

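    A minimal NumPy sketch of this closed-form solution (illustrative only; Q, P, N are the
    row-stacked matrices defined above):

        M = Q.T @ (P - N)                        # M := Q' (P - N)
        A_star = M.T / np.linalg.norm(M, "fro")  # A* = M' / ||M||_F
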
If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
with an orthogonality constraint on A:

    A* := argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
          A
        = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
          A
        = argmax trace[ (P - N) A Q' ]
          A
        = argmax trace[ Q' (P - N) A ]
          A
        = argmax trace[ M A ]
          A
        = argmax trace[ U Σ V' A ]  where  U Σ V' := M is the SVD of M
          A
        = argmax trace[ Σ V' A U ]
          A
          s.t. A'A == 𝕀
        = V U'

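    The corresponding NumPy sketch (illustrative only; note that np.linalg.svd returns V' rather
    than V):

        U, S, VT = np.linalg.svd(M)  # M = U Σ V'
        A_star = VT.T @ U.T          # A* = V U'
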
    Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
    incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones.
    To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀,
    and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap of
    an incorrectly ordered (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score
    gap would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of, say,
    C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
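
    For example, if B = -0.1 and A = 0.4, then C = 0.05 * 0.4 = 0.02 and the optimal blend is
    α = (B - C) / (B - A) = (-0.1 - 0.02) / (-0.1 - 0.4) = 0.24.

    Example usage (a sketch that assumes `insert_document()` and `insert_evals()` have already been
    run to populate the database):

        from raglite._config import RAGLiteConfig
        from raglite._query_adapter import update_query_adapter

        update_query_adapter(config=RAGLiteConfig())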
"""
config = config or RAGLiteConfig()
config_no_query_adapter = RAGLiteConfig(
**{**config.__dict__, "vector_search_query_adapter": False}
)
engine = create_database_engine(config)
with Session(engine) as session:
        # Get evals from the database.
chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
if chunk_embedding is None:
error_message = "First run `insert_document()` to insert documents."
raise ValueError(error_message)
evals = session.exec(
select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
).all()
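        # Require that the evals can yield at least as many triplets as the embedding dimension,
        # so that the d×d query adapter A* is sufficiently constrained.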
if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding):
error_message = "First run `insert_evals()` to generate sufficient evals."
raise ValueError(error_message)
# Loop over the evals to generate (q, p, n) triplets.
Q = np.zeros((0, len(chunk_embedding.embedding))) # noqa: N806
P = np.zeros_like(Q) # noqa: N806
N = np.zeros_like(Q) # noqa: N806
for eval_ in tqdm(
evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
):
# Embed the question.
question_embedding = embed_sentences([eval_.question], config=config)
# Retrieve chunks that would be used to answer the question.
chunk_ids, _ = vector_search(
question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
)
            retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
            # Restore the ranking of the vector search, since a SQL IN query does not preserve it.
            retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
# Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
num_triplets = 0
for i, retrieved_chunk in enumerate(retrieved_chunks):
# Select irrelevant chunks.
if retrieved_chunk.id not in eval_.chunk_ids:
# Look up all positive chunks (each represented by the mean of its multi-vector
# embedding) that are ranked lower than this negative one (represented by the
# embedding in the multi-vector embedding that best matches the query).
p_mean = [
np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
for chunk in retrieved_chunks[i + 1 :]
if chunk is not None and chunk.id in eval_.chunk_ids
]
n_top = retrieved_chunk.embedding_matrix[
np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
np.newaxis,
:,
]
# Filter out any (p, n, q) triplets for which the mean positive embedding ranks
# higher than the top negative one.
p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
if not p_mean:
continue
# Stack the (p, n, q) triplets.
p = np.vstack(p_mean)
n = np.repeat(n_top, p.shape[0], axis=0)
q = np.repeat(question_embedding, p.shape[0], axis=0)
num_triplets += p.shape[0]
# Append the (query, positive, negative) tuples to the Q, P, N matrices.
Q = np.vstack([Q, q]) # noqa: N806
P = np.vstack([P, p]) # noqa: N806
N = np.vstack([N, n]) # noqa: N806
# Check if we have sufficient triplets for this eval.
if num_triplets >= max_triplets_per_eval:
break
# Check if we have sufficient triplets to compute the query adapter.
if Q.shape[0] > max_triplets:
Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :] # noqa: N806
break
# Normalise the rows of Q, P, N.
Q /= np.linalg.norm(Q, axis=1, keepdims=True) # noqa: N806
P /= np.linalg.norm(P, axis=1, keepdims=True) # noqa: N806
N /= np.linalg.norm(N, axis=1, keepdims=True) # noqa: N806
# Compute the optimal weighted query adapter A*.
# TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
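        # gap_before is B := (p - n)' q from the docstring. Since the rows of P and N are unit
        # vectors, (p - n)'(p - n) = 2(1 - p'n), so gap_after is A := (p - n)'(p - n) / ||p - n||.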
gap_before = np.sum((P - N) * Q, axis=1)
gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
gap_target = 0.05 * gap_after
α = (gap_before - gap_target) / (gap_before - gap_after) # noqa: PLC2401
MT = (α[:, np.newaxis] * (P - N)).T @ Q # noqa: N806
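        # The Frobenius norm of the d×d identity is √d, so MT / s has the same norm as 𝕀.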
s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1]) # noqa: N806
if config.vector_search_index_metric == "dot":
# Use the relaxed Procrustes solution.
A_star = MT / np.linalg.norm(MT, ord="fro") # noqa: N806
elif config.vector_search_index_metric == "cosine":
# Use the orthogonal Procrustes solution.
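            # The SVD is taken of MT = M' = V Σ U', so U and VT below correspond to the
            # docstring's V and U', and U @ VT is the optimal A* = V U'.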
U, _, VT = np.linalg.svd(MT, full_matrices=False) # noqa: N806
A_star = U @ VT # noqa: N806
else:
error_message = f"Unsupported ANN metric: {config.vector_search_index_metric}"
raise ValueError(error_message)
# Store the optimal query adapter in the database.
index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
session.add(index_metadata)
session.commit()