File size: 5,610 Bytes
2eb9217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import itertools
import json

from datasets import load_dataset
import faiss
import pandas as pd
import numpy as np
import torch

from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer


class InstructionTemplateRetriever:
    """Retrieve instruction templates that are relevant to a document.

    The retriever embeds a document with the pinned
    ``fineinstructions/matching_embedding`` model, searches the FAISS index
    shipped with the pinned ``fineinstructions/finetemplates`` dataset, and
    returns the matching dataset rows annotated with their score and the
    part of the document ("coverage section") that produced the match.
    """

    # Revisions are pinned so the dataset rows, the FAISS index and the
    # embedding model always come from the same snapshots.
    FINETEMPLATES_REVISION = "831ab22c90f9da011bd972585afdf609f40fa54b"
    RETRIEVAL_EMBEDDING_NAME = "fineinstructions/matching_embedding"
    RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"

    def __init__(
        self,
        coverage_chunks=10,
        sigma=0.05,
        alpha=1.0,
        nprobe=150,
    ):
        """
        Computes embeddings that cover a document to find relevant
        instruction templates using Gaussian-weighted embeddings that cover
        different parts of the document.

        Constructing the retriever loads the FineTemplates dataset, the
        embedding model and the FAISS index, so it may download data from the
        Hugging Face Hub.

        Args:
            coverage_chunks (int): The number of equally sized chunks/sections
            to get coverage over the entire document.
            sigma (float): Standard deviation for Gaussian weighting, this
            will essentially control how "wide" / "focused" each chunk is.
            alpha (float): A weighting factor to control how much to balance
            the representation of a single chunk, versus the representation of
            the entire document.
            nprobe (int): The number of probes to use when searching the FAISS
            index (larger is more accurate, but slower).
        """
        self.d = load_dataset(
            "fineinstructions/finetemplates",
            revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
            split="full",
        )
        self.m = SentenceTransformer(
            InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME,
            revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION,
            device="cpu",
        )
        # `use_gaussian_coverage_pooling` is defined outside this class; it
        # returns the model that `search` later indexes as `self.m[0]`
        # (transformer) and `self.m[1]` (pooling with `.coverage_chunks`).
        self.m = use_gaussian_coverage_pooling(
            self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
        )
        # The index is opened memory-mapped and read-only instead of being
        # read fully into memory.
        self.index = faiss.read_index(
            hf_hub_download(
                "fineinstructions/finetemplates",
                "faiss_index/finetemplates.index",
                revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
                repo_type="dataset",
            ),
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
        )
        self.index.nprobe = nprobe
        # The model is created on CPU and moved to an accelerator if one is
        # available (CUDA is preferred over MPS).
        if torch.cuda.is_available():
            self.m = self.m.to("cuda")
        elif torch.backends.mps.is_available():
            self.m = self.m.to("mps")

    def _filter_rows(self, rows, filter_string):
        """Filter result rows with a ``pandas.DataFrame.query`` expression.

        Filtering is best-effort: if the expression cannot be evaluated the
        rows are returned unfiltered and the failure is logged.

        Args:
            rows (list[dict]): The result rows to filter.
            filter_string (str): A ``pandas.DataFrame.query`` expression. An
            empty, whitespace-only or ``None`` value means "no filter".

        Returns:
            list[dict]: The rows matching the expression. When a filter is
            applied the rows are rebuilt from a DataFrame via
            ``to_dict(orient="records")``; otherwise ``rows`` is returned
            as-is.
        """
        if not rows:
            return []
        # No filter requested: skip building a DataFrame altogether.
        if filter_string is None or (
            isinstance(filter_string, str) and not filter_string.strip()
        ):
            return rows
        try:
            return pd.DataFrame(rows).query(filter_string).to_dict(orient="records")
        except Exception:
            # `query` can raise many different exception types (syntax errors,
            # unknown columns, type errors, ...). Keep the original fallback of
            # returning everything, but leave a trace instead of failing
            # silently.
            import logging

            logging.getLogger(__name__).warning(
                "Could not apply filter %r; returning unfiltered rows.",
                filter_string,
                exc_info=True,
            )
            return rows

    @staticmethod
    def _interleave_rows(rows_per_chunk, deduplicate):
        """Merge per-chunk result lists round-robin by rank.

        The output contains the first result of every chunk, then the second
        result of every chunk, and so on. Chunks that run out of results are
        skipped for the remaining ranks.

        Args:
            rows_per_chunk (list[list[dict]]): One result list per coverage
            chunk.
            deduplicate (bool): If true, keep only the first row seen for each
            ``template_id``.

        Returns:
            list[dict]: The merged rows.
        """
        merged = []
        seen_template_ids = set()
        for same_rank_rows in itertools.zip_longest(*rows_per_chunk):
            for row in same_rank_rows:
                if row is None:  # This chunk has fewer results than others.
                    continue
                if deduplicate:
                    template_id = row["template_id"]
                    if template_id in seen_template_ids:
                        continue
                    seen_template_ids.add(template_id)
                merged.append(row)
        return merged

    def search(
        self, document, filters="", search_k=20000, max_results=250, deduplicate=True
    ):
        """
        Given a document, retrieve the instruction templates that best match
        the document as a whole and each of its coverage sections.

        Args:
            document (str): The document to retrieve relevant instruction templates for.
            filters (str): A query string in the format of pandas.DataFrame.query()
            search_k (int): The number of search results to pull when retrieving from FAISS.
            max_results (int): The max number of results to return.
            deduplicate (bool): Deduplicate results between coverage sections.

        Returns:
            list[dict]: At most ``max_results`` rows. Each row holds
            ``coverage_section`` ("Entire Document" or "<chunk>/<total>"),
            ``score`` (the value returned by FAISS) and the columns of the
            matching FineTemplates row. Rows are interleaved by rank across
            coverage sections.
        """

        # Search FAISS index. The encoder output is reshaped to one row per
        # embedding vector of width `hidden_size`.
        embeddings = self.m.encode([document], normalize_embeddings=False).reshape(
            -1, self.m[0].auto_model.config.hidden_size
        )
        scores_batch, indices_batch = self.index.search(
            np.vstack(embeddings), k=search_k
        )

        # Only one document is encoded, so its vectors are the first
        # `coverage_chunks + 1` rows: index 0 is labelled "Entire Document"
        # below, the rest are the individual coverage chunks.
        coverage_chunks = self.m[1].coverage_chunks
        vectors_per_document = coverage_chunks + 1
        scores_per_chunk = scores_batch[:vectors_per_document]
        indices_per_chunk = indices_batch[:vectors_per_document]

        # Pull the FineTemplates rows into memory. Each row is fetched only
        # once even if several chunks hit it, and FAISS's "-1" padding label
        # (used when fewer than `search_k` neighbours exist) is dropped so it
        # is never mistaken for a dataset index.
        hit_ids = list(
            dict.fromkeys(
                int(idx)
                for idx in itertools.chain.from_iterable(indices_per_chunk)
                if idx >= 0
            )
        )
        templates = (
            dict(zip(hit_ids, self.d.select(hit_ids).to_list())) if hit_ids else {}
        )

        # Create result rows, one list per coverage chunk.
        rows_per_chunk = []
        for chunk_idx, (indices, scores) in enumerate(
            zip(indices_per_chunk, scores_per_chunk)
        ):
            section = (
                f"{chunk_idx}/{coverage_chunks}"
                if chunk_idx > 0
                else "Entire Document"
            )
            rows_per_chunk.append(
                [
                    {
                        "coverage_section": section,
                        "score": score.item(),
                        **templates[int(idx)],
                    }
                    for idx, score in zip(indices, scores)
                    if idx >= 0
                ]
            )

        # Interleave (and optionally deduplicate) the per-chunk results.
        rows = self._interleave_rows(rows_per_chunk, deduplicate)

        # Filter, then truncate to the requested number of results.
        return self._filter_rows(rows, filters)[:max_results]