"""Compute and update an optimal query adapter."""

import numpy as np
from sqlmodel import Session, col, select
from tqdm.auto import tqdm

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
from raglite._embed import embed_sentences
from raglite._search import vector_search


def _compute_query_adapter(
    Q: np.ndarray,  # noqa: N803
    P: np.ndarray,  # noqa: N803
    N: np.ndarray,  # noqa: N803
    *,
    metric: str,
) -> np.ndarray:
    """Compute the optimal weighted query adapter A* from row-aligned (q, p, n) triplets.

    Args:
        Q: Query embeddings, one triplet per row.
        P: Positive chunk embeddings, row-aligned with Q.
        N: Negative chunk embeddings, row-aligned with Q.
        metric: The ANN index metric, either "dot" or "cosine".

    Returns:
        The query adapter A*, a square matrix of size Q.shape[1].

    Raises:
        ValueError: If `metric` is neither "dot" nor "cosine".
    """
    # Normalise the rows of Q, P, N (without mutating the caller's arrays).
    Q = Q / np.linalg.norm(Q, axis=1, keepdims=True)  # noqa: N806
    P = P / np.linalg.norm(P, axis=1, keepdims=True)  # noqa: N806
    N = N / np.linalg.norm(N, axis=1, keepdims=True)  # noqa: N806
    # Per-triplet relevance score gaps (see the `update_query_adapter` docstring):
    # - B = (p - n)' q, the gap before adapting;
    # - A = ||p - n||, the gap when α = 1 (p and n are unit vectors, so ||p - n||² = 2 (1 - p'n));
    # - C = 5% of A, the target gap.
    # Solving αA + (1 - α)B = C gives each triplet's weight α.
    # TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
    gap_before = np.sum((P - N) * Q, axis=1)
    gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
    gap_target = 0.05 * gap_after
    alpha = (gap_before - gap_target) / (gap_before - gap_after)
    # M' = (P - N)' Q, where each triplet's contribution is weighted by its own α.
    MT = (alpha[:, np.newaxis] * (P - N)).T @ Q  # noqa: N806
    # Scale M' to the Frobenius norm of the identity (√d), then blend it with the identity using
    # the mean α.
    s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
    MT = np.mean(alpha) * (MT / s) + np.mean(1 - alpha) * np.eye(Q.shape[1])  # noqa: N806
    if metric == "dot":
        # Relaxed Procrustes solution: A* = M' / ||M||_F.
        return MT / np.linalg.norm(MT, ord="fro")
    if metric == "cosine":
        # Orthogonal Procrustes solution: the SVD of M' is V Σ U', so U @ VT below equals V U'.
        U, _, VT = np.linalg.svd(MT, full_matrices=False)  # noqa: N806
        return U @ VT
    error_message = f"Unsupported ANN metric: {metric}"
    raise ValueError(error_message)


def update_query_adapter(  # noqa: PLR0915, C901
    *,
    max_triplets: int = 4096,
    max_triplets_per_eval: int = 64,
    optimize_top_k: int = 40,
    config: RAGLiteConfig | None = None,
) -> None:
    """Compute an optimal query adapter and update the database with it.

    This function computes an optimal linear transform A, called a 'query adapter', that is used to
    transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in
    order to improve the quality of the search results.

    Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
    score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.

    If the nearest neighbour search uses the dot product as its relevance score, we can find the
    optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
    bound on the Frobenius norm of A:

        A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
           = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
           = argmax trace[ (P - N) A Q' ]   where  Q := [q₁'; ...; qₖ']
                                                   P := [p₁'; ...; pₖ']
                                                   N := [n₁'; ...; nₖ']
           = argmax trace[ Q' (P - N) A ]
           = argmax trace[ M A ]            where  M := Q' (P - N)
             s.t. ||A||_F == 1
           = M' / ||M||_F

    If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
    the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
    with an orthogonality constraint on A:

        A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
           = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
           = argmax trace[ (P - N) A Q' ]
           = argmax trace[ Q' (P - N) A ]
           = argmax trace[ M A ]
           = argmax trace[ U Σ V' A ]       where  U Σ V' := M is the SVD of M
           = argmax trace[ Σ V' A U ]
             s.t. A'A == 𝕀
           = V U'

    Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
    incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered
    ones. To achieve this, we rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm
    as 𝕀, and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score
    gap of an incorrectly ordered (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance
    score gap would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of,
    say, C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
    In the implementation, α is computed per triplet: each triplet's contribution to M is weighted
    by its own α, and the mean α is used to blend the scaled M with the identity.

    Args:
        max_triplets: The maximum number of (q, p, n) triplets used to compute the adapter.
        max_triplets_per_eval: The maximum number of triplets extracted from a single eval.
        optimize_top_k: The number of chunks retrieved per eval question (with vector search and
            without a query adapter) from which triplets are mined.
        config: The RAGLite config. A default `RAGLiteConfig()` is used when None.

    Raises:
        ValueError: If the database has no chunk embeddings, if there are too few evals, if no
            triplets could be extracted from the evals, or if the configured vector search index
            metric is neither "dot" nor "cosine".
    """
    config = config or RAGLiteConfig()
    # Mine triplets from plain vector search results, without any existing query adapter.
    config_no_query_adapter = RAGLiteConfig(
        **{**config.__dict__, "vector_search_query_adapter": False}
    )
    engine = create_database_engine(config)
    with Session(engine) as session:
        # Check that there are chunk embeddings; one of them also tells us the embedding dimension.
        chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
        if chunk_embedding is None:
            error_message = "First run `insert_document()` to insert documents."
            raise ValueError(error_message)
        # Get the first evals (ordered by ID) from the database.
        evals = session.exec(
            select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
        ).all()
        if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding):
            error_message = "First run `insert_evals()` to generate sufficient evals."
            raise ValueError(error_message)
        # Loop over the evals to generate (q, p, n) triplets. Triplet blocks are collected in lists
        # and stacked once at the end, instead of re-stacking the full matrices for every block.
        q_blocks: list[np.ndarray] = []
        p_blocks: list[np.ndarray] = []
        n_blocks: list[np.ndarray] = []
        total_triplets = 0
        for eval_ in tqdm(
            evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
        ):
            # Embed the question.
            question_embedding = embed_sentences([eval_.question], config=config)
            # Retrieve chunks that would be used to answer the question.
            chunk_ids, _ = vector_search(
                question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
            )
            # An SQL `IN` query does not preserve the order of `chunk_ids`, so restore the vector
            # search ranking explicitly: the triplet extraction below depends on it.
            chunks_by_id = {
                chunk.id: chunk
                for chunk in session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
            }
            retrieved_chunks = [
                chunks_by_id[chunk_id]
                for chunk_id in dict.fromkeys(chunk_ids)
                if chunk_id in chunks_by_id
            ]
            relevant_chunk_ids = set(eval_.chunk_ids)
            # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
            num_triplets = 0
            for i, retrieved_chunk in enumerate(retrieved_chunks):
                # Only irrelevant chunks can serve as the negative of a triplet.
                if retrieved_chunk.id in relevant_chunk_ids:
                    continue
                # Look up all positive chunks (each represented by the mean of its multi-vector
                # embedding) that are ranked lower than this negative one (represented by the
                # embedding in the multi-vector embedding that best matches the query).
                p_mean = [
                    np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
                    for chunk in retrieved_chunks[i + 1 :]
                    if chunk.id in relevant_chunk_ids
                ]
                n_top = retrieved_chunk.embedding_matrix[
                    np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
                    np.newaxis,
                    :,
                ]
                # Filter out any (p, n, q) triplets for which the mean positive embedding ranks
                # higher than the top negative one.
                p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
                if not p_mean:
                    continue
                # Stack the (p, n, q) triplets for this negative chunk.
                p = np.vstack(p_mean)
                q_blocks.append(np.repeat(question_embedding, p.shape[0], axis=0))
                p_blocks.append(p)
                n_blocks.append(np.repeat(n_top, p.shape[0], axis=0))
                num_triplets += p.shape[0]
                total_triplets += p.shape[0]
                # Check if we have sufficient triplets for this eval.
                if num_triplets >= max_triplets_per_eval:
                    break
            # Check if we have sufficient triplets to compute the query adapter.
            if total_triplets >= max_triplets:
                break
        # Without any triplets the adapter would be all NaN, so refuse to store it.
        if total_triplets == 0:
            error_message = (
                "Could not extract any (query, positive, negative) triplets from the evals. "
                "Try increasing `optimize_top_k` or run `insert_evals()` to generate more evals."
            )
            raise ValueError(error_message)
        # Stack the triplets in double precision and keep at most `max_triplets` of them.
        Q = np.vstack(q_blocks).astype(np.float64)[:max_triplets, :]  # noqa: N806
        P = np.vstack(p_blocks).astype(np.float64)[:max_triplets, :]  # noqa: N806
        N = np.vstack(n_blocks).astype(np.float64)[:max_triplets, :]  # noqa: N806
        # Compute the optimal weighted query adapter A*.
        A_star = _compute_query_adapter(  # noqa: N806
            Q, P, N, metric=config.vector_search_index_metric
        )
        # Store the optimal query adapter in the database.
        index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
        index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
        session.add(index_metadata)
        session.commit()