"""Compute and update an optimal query adapter."""

import numpy as np
from sqlmodel import Session, col, select
from tqdm.auto import tqdm

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
from raglite._embed import embed_sentences
from raglite._search import vector_search


def update_query_adapter(  # noqa: PLR0915, C901
    *,
    max_triplets: int = 4096,
    max_triplets_per_eval: int = 64,
    optimize_top_k: int = 40,
    config: RAGLiteConfig | None = None,
) -> None:
    """Compute an optimal query adapter and update the database with it.

    This function computes an optimal linear transform A, called a 'query adapter', that is used to
    transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in
    order to improve the quality of the search results.

    Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
    score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.

    If the nearest neighbour search uses the dot product as its relevance score, we can find the
    optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
    bound on the Frobenius norm of A:

    A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
       = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
       = argmax trace[ (P - N) A Q' ]  where  Q := [q₁'; ...; qₖ']
                                              P := [p₁'; ...; pₖ']
                                              N := [n₁'; ...; nₖ']
       = argmax trace[ Q' (P - N) A ]
       = argmax trace[ M A ]           where  M := Q' (P - N)
           s.t. ||A||_F == 1
       = M' / ||M||_F
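
    For illustration, a minimal NumPy sketch of this relaxed Procrustes solution on hypothetical
    random data (8 triplets in 4 dimensions, with rows unit-normalised as in the implementation
    below):

        import numpy as np

        rng = np.random.default_rng(42)
        Q = rng.standard_normal((8, 4))
        P = rng.standard_normal((8, 4))
        N = rng.standard_normal((8, 4))
        Q, P, N = (X / np.linalg.norm(X, axis=1, keepdims=True) for X in (Q, P, N))
        M = Q.T @ (P - N)
        A_star = M.T / np.linalg.norm(M, ord="fro")  # ||A*||_F == 1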

    If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
    the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
    with an orthogonality constraint on A:

    A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
       = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
       = argmax trace[ (P - N) A Q' ]
       = argmax trace[ Q' (P - N) A ]
       = argmax trace[ M A ]
       = argmax trace[ U Σ V' A ]      where  U Σ V' := M is the SVD of M
       = argmax trace[ Σ V' A U ]
           s.t. A'A == 𝕀
       = V U'
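
    Continuing the sketch above, the orthogonal solution follows from the SVD of the same
    (hypothetical) M:

        U, _, VT = np.linalg.svd(M, full_matrices=False)
        A_star = VT.T @ U.T  # A* = V U' is orthogonal: A_star.T @ A_star ≈ np.eye(4)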

    Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
    incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones.
    To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀,
    and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap
    between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
    would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
    C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
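    For example, a hypothetical triplet with B = -0.1 and A = 0.5 yields a target gap of
    C = 0.025, and therefore α = (-0.1 - 0.025) / (-0.1 - 0.5) ≈ 0.21.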
    """
    config = config or RAGLiteConfig()
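    # Mine triplets with unadapted queries by disabling any existing query adapter.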
    config_no_query_adapter = RAGLiteConfig(
        **{**config.__dict__, "vector_search_query_adapter": False}
    )
    engine = create_database_engine(config)
    with Session(engine) as session:
        # Check that documents have been inserted and get the embedding dimension.
        chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
        if chunk_embedding is None:
            error_message = "First run `insert_document()` to insert documents."
            raise ValueError(error_message)
        # Get random evals from the database.
        evals = session.exec(
            select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
        ).all()
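        # The adapter can only be estimated reliably with at least d triplets, one per dimension.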
        if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding):
            error_message = "First run `insert_evals()` to generate sufficient evals."
            raise ValueError(error_message)
        # Loop over the evals to generate (q, p, n) triplets.
        Q = np.zeros((0, len(chunk_embedding.embedding)))  # noqa: N806
        P = np.zeros_like(Q)  # noqa: N806
        N = np.zeros_like(Q)  # noqa: N806
        for eval_ in tqdm(
            evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
        ):
            # Embed the question.
            question_embedding = embed_sentences([eval_.question], config=config)
            # Retrieve chunks that would be used to answer the question.
            chunk_ids, _ = vector_search(
                question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
            )
            retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
            # The `in_` filter does not preserve the search ranking, so restore it.
            retrieved_chunks = sorted(
                retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id)
            )
            # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
            num_triplets = 0
            for i, retrieved_chunk in enumerate(retrieved_chunks):
                # Select irrelevant chunks.
                if retrieved_chunk.id not in eval_.chunk_ids:
                    # Look up all positive chunks (each represented by the mean of its multi-vector
                    # embedding) that are ranked lower than this negative one (represented by the
                    # embedding in the multi-vector embedding that best matches the query).
                    p_mean = [
                        np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
                        for chunk in retrieved_chunks[i + 1 :]
                        if chunk is not None and chunk.id in eval_.chunk_ids
                    ]
                    n_top = retrieved_chunk.embedding_matrix[
                        np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
                        np.newaxis,
                        :,
                    ]
                    # Filter out any (p, n, q) triplets for which the mean positive embedding ranks
                    # higher than the top negative one.
                    p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
                    if not p_mean:
                        continue
                    # Stack the (p, n, q) triplets.
                    p = np.vstack(p_mean)
                    n = np.repeat(n_top, p.shape[0], axis=0)
                    q = np.repeat(question_embedding, p.shape[0], axis=0)
                    num_triplets += p.shape[0]
                    # Append the (query, positive, negative) tuples to the Q, P, N matrices.
                    Q = np.vstack([Q, q])  # noqa: N806
                    P = np.vstack([P, p])  # noqa: N806
                    N = np.vstack([N, n])  # noqa: N806
                    # Check if we have sufficient triplets for this eval.
                    if num_triplets >= max_triplets_per_eval:
                        break
            # Check if we have sufficient triplets to compute the query adapter.
            if Q.shape[0] > max_triplets:
                Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :]  # noqa: N806
                break
        # Normalise the rows of Q, P, N.
        Q /= np.linalg.norm(Q, axis=1, keepdims=True)  # noqa: N806
        P /= np.linalg.norm(P, axis=1, keepdims=True)  # noqa: N806
        N /= np.linalg.norm(N, axis=1, keepdims=True)  # noqa: N806
        # Compute the optimal weighted query adapter A*.
        # TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
        gap_before = np.sum((P - N) * Q, axis=1)
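        # For unit vectors, ||p - n||² == 2 (1 - p'n), so gap_after is (p - n)'(p - n) / ||p - n||.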
        gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
        gap_target = 0.05 * gap_after
        α = (gap_before - gap_target) / (gap_before - gap_after)  # noqa: PLC2401
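        # Compute the α-weighted transpose M' = (P - N)' diag(α) Q of M := Q' (P - N).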
        MT = (α[:, np.newaxis] * (P - N)).T @ Q  # noqa: N806
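        # A d×d identity has Frobenius norm √d, so MT / s has the same norm as 𝕀.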
        s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
        MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1])  # noqa: N806
        if config.vector_search_index_metric == "dot":
            # Use the relaxed Procrustes solution.
            A_star = MT / np.linalg.norm(MT, ord="fro")  # noqa: N806
        elif config.vector_search_index_metric == "cosine":
            # Use the orthogonal Procrustes solution.
            U, _, VT = np.linalg.svd(MT, full_matrices=False)  # noqa: N806
            A_star = U @ VT  # noqa: N806
        else:
            error_message = f"Unsupported ANN metric: {config.vector_search_index_metric}"
            raise ValueError(error_message)
        # Store the optimal query adapter in the database.
        index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
        index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
        session.add(index_metadata)
        session.commit()