|
import itertools
import json
import warnings

import faiss
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
|
|
|
|
|
class InstructionTemplateRetriever:
    """Retrieve instruction templates that are relevant to a document.

    The retriever loads the ``fineinstructions/finetemplates`` dataset, its
    pre-built FAISS index, and a sentence-transformers embedding model that is
    wrapped with Gaussian coverage pooling, so that one document yields one
    embedding for the whole document plus one per coverage section.
    """

    # Pinned revisions so the dataset, the FAISS index and the embedding model
    # are always loaded from the same snapshots.
    FINETEMPLATES_REVISION = "831ab22c90f9da011bd972585afdf609f40fa54b"
    RETRIEVAL_EMBEDDING_NAME = "fineinstructions/matching_embedding"
    RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"

    def __init__(
        self,
        coverage_chunks=10,
        sigma=0.05,
        alpha=1.0,
        nprobe=150,
    ):
        """
        Computes embeddings that cover a document to find relevant
        instruction templates using Gaussian-weighted embeddings that cover
        different parts of the document.

        Args:
            coverage_chunks (int): The number of equally sized chunks/sections
                to get coverage over the entire document.
            sigma (float): Standard deviation for Gaussian weighting, this
                will essentially control how "wide" / "focused" each chunk is.
            alpha (float): A weighting factor to control how much to balance
                the representation of a single chunk, versus the representation of
                the entire document.
            nprobe (int): The number of probes to use when searching the FAISS
                index (larger is more accurate, but slower).
        """
        self.d = load_dataset(
            "fineinstructions/finetemplates",
            revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
            split="full",
        )
        # Load on CPU first; the model is moved to an accelerator below if one
        # is available.
        self.m = SentenceTransformer(
            InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME,
            revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION,
            device="cpu",
        )
        self.m = use_gaussian_coverage_pooling(
            self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
        )
        # The index is memory-mapped read-only instead of being loaded fully
        # into RAM.
        self.index = faiss.read_index(
            hf_hub_download(
                "fineinstructions/finetemplates",
                "faiss_index/finetemplates.index",
                revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
                repo_type="dataset",
            ),
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
        )
        self.index.nprobe = nprobe
        if torch.cuda.is_available():
            self.m = self.m.to("cuda")
        elif torch.backends.mps.is_available():
            self.m = self.m.to("mps")

    def _filter_rows(self, rows, filter_string):
        """Filter result rows with a ``pandas.DataFrame.query`` expression.

        Filtering is best-effort: when the expression cannot be evaluated a
        warning is emitted and the unfiltered rows are returned.

        NOTE(review): ``DataFrame.query`` evaluates ``filter_string`` as an
        expression; pandas documents a code-injection risk for untrusted
        input, so only trusted filter strings should be passed here.

        Args:
            rows (list[dict]): The rows to filter.
            filter_string (str): The query expression; empty/blank means
                "no filtering".

        Returns:
            list[dict]: The rows matching the filter, or ``rows`` itself when
            there is no filter or the filter is invalid.
        """
        if not rows:
            return []
        # No filter requested: skip the DataFrame round-trip instead of
        # relying on ``query("")`` raising.
        if not filter_string or not filter_string.strip():
            return rows
        df = pd.DataFrame(rows)
        try:
            filtered_df = df.query(filter_string)
        except Exception as e:
            # Deliberately broad: query() can raise many error types
            # (syntax errors, undefined columns, type errors, ...).
            warnings.warn(
                f"Ignoring invalid filter {filter_string!r}: {e}", stacklevel=2
            )
            return rows
        return filtered_df.to_dict(orient="records")

    def search(
        self, document, filters="", search_k=20000, max_results=250, deduplicate=True
    ):
        """
        Given a document, retrieve the instruction templates that best match
        the document as a whole and each of its coverage sections.

        Results are interleaved round-robin across the coverage sections
        (rank 1 of every section, then rank 2 of every section, ...), then
        filtered, then truncated to ``max_results``.

        Args:
            document (str): The document to retrieve relevant instruction templates for.
            filters (str): A query string in the format of pandas.DataFrame.query()
            search_k (int): The number of search results to pull when retrieving from FAISS.
            max_results (int): The max number of results to return.
            deduplicate (bool): Deduplicate results between coverage sections.

        Returns:
            list[dict]: One dict per result with the keys ``coverage_section``
            ("Entire Document" or "<section>/<coverage_chunks>"), ``score``
            (as returned by the FAISS index), and the columns of the matching
            dataset row.
        """
        # The pooled output holds one hidden_size-wide vector per coverage
        # section; split it back into one row per section.
        hidden_size = self.m[0].auto_model.config.hidden_size
        vecs = self.m.encode([document], normalize_embeddings=False).reshape(
            -1, hidden_size
        )
        scores_batch, indices_batch = self.index.search(np.vstack(vecs), k=search_k)

        # One extra section: section 0 is labelled as the entire document.
        coverage_chunks = self.m[1].coverage_chunks
        n_sections = coverage_chunks + 1
        # Only a single document is encoded, so its sections are the first
        # n_sections rows of the batch.
        section_ids = np.asarray(indices_batch)[:n_sections]
        section_scores = np.asarray(scores_batch)[:n_sections]

        # Fetch every referenced dataset row once. FAISS pads with -1 when it
        # finds fewer than k neighbours; those labels must never be looked up
        # as dataset rows.
        unique_ids = np.unique(section_ids)
        unique_ids = unique_ids[unique_ids >= 0].tolist()
        fetched = self.d.select(unique_ids).to_list() if unique_ids else []
        rows_by_id = dict(zip(unique_ids, fetched))

        # Walk rank by rank and, within a rank, section by section, so that
        # every section contributes its best matches first.
        rows = []
        seen = set()
        for rank_ids, rank_scores in zip(
            section_ids.T.tolist(), section_scores.T.tolist()
        ):
            for section, (row_id, score) in enumerate(zip(rank_ids, rank_scores)):
                if row_id < 0:
                    continue
                record = rows_by_id[row_id]
                if deduplicate:
                    template_id = record["template_id"]
                    if template_id in seen:
                        continue
                    seen.add(template_id)
                rows.append(
                    {
                        "coverage_section": f"{section}/{coverage_chunks}"
                        if section > 0
                        else "Entire Document",
                        "score": score,
                        **record,
                    }
                )

        # Filter before truncating so max_results counts matching rows only.
        return self._filter_rows(rows, filters)[:max_results]
|
|