# instruction_template_retrieval_embedding/instruction_template_retriever.py
import itertools
import json
from datasets import load_dataset
import faiss
import pandas as pd
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
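

# The docstring of InstructionTemplateRetriever.__init__ below describes
# Gaussian-weighted "coverage" embeddings: the document is split into
# `coverage_chunks` equally sized sections, and each section contributes a
# Gaussian-shaped token-weight profile. The helper below is only a minimal
# sketch of how such per-chunk weight profiles could be computed; it is NOT
# the actual `use_gaussian_coverage_pooling` wrapper used in __init__, which
# is assumed to be defined or imported elsewhere in this repository.
def _illustrative_gaussian_chunk_weights(num_tokens, coverage_chunks, sigma):
    """Return a (coverage_chunks, num_tokens) array of Gaussian token weights."""
    positions = np.linspace(0, 1, num_tokens)  # normalized token positions in [0, 1]
    centers = (np.arange(coverage_chunks) + 0.5) / coverage_chunks  # evenly spaced chunk centers
    weights = np.exp(-((positions[None, :] - centers[:, None]) ** 2) / (2 * sigma**2))
    # Normalize each chunk's weights so they sum to 1 (weighted mean pooling)
    return weights / weights.sum(axis=1, keepdims=True)
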
class InstructionTemplateRetriever:
FINETEMPLATES_REVISION = "831ab22c90f9da011bd972585afdf609f40fa54b"
RETRIEVAL_EMBEDDING_NAME = "fineinstructions/matching_embedding"
RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"
def __init__(
self,
coverage_chunks=10,
sigma=0.05,
alpha=1.0,
nprobe=150,
):
"""
Computes embeddings that cover a document to find relevant
instruction templates using Gaussian-weighted embeddings that cover
different parts of the document.
Args:
coverage_chunks (int): The number of equally sized chunks/sections
to get coverage over the entire document.
sigma (float): Standard deviation for Gaussian weighting, this
will essentially control how "wide" / "focused" each chunk is.
alpha (float): A weighting factor to control how much to balance
the representation of a single chunk, versus the representation of
the entire document.
nprobe (int): The number of probes to use when searching the FAISS
index (larger is more accurate, but slower).
"""
self.d = load_dataset(
"fineinstructions/finetemplates",
revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
split="full",
)
self.m = SentenceTransformer(
InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME,
revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION,
device="cpu",
)
        # Wrap the model's pooling with Gaussian coverage pooling so that a
        # single encode() call yields an embedding for the entire document plus
        # one per coverage chunk. (`use_gaussian_coverage_pooling` is assumed
        # to be defined or imported elsewhere in this repository.)
        self.m = use_gaussian_coverage_pooling(
            self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
        )
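        # Download the prebuilt FAISS index from the Hub and memory-map it
        # (IO_FLAG_MMAP) read-only instead of loading it fully into RAM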
self.index = faiss.read_index(
hf_hub_download(
"fineinstructions/finetemplates",
"faiss_index/finetemplates.index",
revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
repo_type="dataset",
),
faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
)
self.index.nprobe = nprobe
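        # Move the embedding model to a GPU (CUDA) or Apple Silicon (MPS) device if available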
if torch.cuda.is_available():
self.m = self.m.to("cuda")
elif torch.backends.mps.is_available():
self.m = self.m.to("mps")
def _filter_rows(self, rows, filter_string):
if not rows:
return []
df = pd.DataFrame(rows)
try:
filtered_df = df.query(filter_string)
return filtered_df.to_dict(orient="records")
        except Exception:
            # Fall back to the unfiltered rows if the filter string is invalid
            return rows
def search(
self, document, filters="", search_k=20000, max_results=250, deduplicate=True
):
"""
Given a document
Args:
document (str): The document to retrieve relevant instruction templates for.
filters (str): A query string in the format of pandas.DataFrame.query()
search_k (int): The number of search results to pull when retrieving from FAISS.
max_results (int): The max number of results to return.
deduplicate (bool): Deduplicate results between coverage sections.
"""
# Search FAISS index
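        # The wrapped model packs one embedding for the entire document plus one
        # per coverage chunk into its output; the reshape splits them into one
        # row per embedding before searching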
vecs = self.m.encode([document], normalize_embeddings=False).reshape(
-1, self.m[0].auto_model.config.hidden_size
)
scores_batch, indices_batch = self.index.search(np.vstack(vecs), k=search_k)
        # Pull the matching FineTemplates rows into memory, keyed by dataset index
to_select = [i.item() for i in itertools.chain.from_iterable(indices_batch)]
d_in_mem = {
i: row for i, row in zip(to_select, self.d.select(to_select).to_list())
}
        # Group by coverage chunk: the wrapped model yields (coverage_chunks + 1)
        # result sets per input (the entire document plus one per chunk)
true_coverage_chunks = self.m[1].coverage_chunks + 1
scores_per_input, indices_per_input = (
[
scores_batch[i : i + true_coverage_chunks]
for i in range(0, len(scores_batch), true_coverage_chunks)
],
[
indices_batch[i : i + true_coverage_chunks]
for i in range(0, len(indices_batch), true_coverage_chunks)
],
)
        # Keep only the results for the first input (a batch size of 1 is assumed)
scores_per_input, indices_per_input = scores_per_input[0], indices_per_input[0]
        # Create result rows (one list of rows per coverage section)
rows = [
[
{
"coverage_section": f"{chunk_idx}/{self.m[1].coverage_chunks}"
if chunk_idx > 0
else "Entire Document",
"score": s.item(),
**d_in_mem[i.item()],
}
for i, s in zip(indices, scores)
]
for chunk_idx, (indices, scores) in enumerate(
zip(indices_per_input, scores_per_input)
)
]
        # Interleave results across coverage sections (round-robin) and, if
        # requested, drop templates already seen in an earlier section
        interleaved = itertools.chain.from_iterable(zip(*rows))
        if deduplicate:
            seen = set()
            deduped = []
            for r in interleaved:
                if r["template_id"] not in seen:
                    seen.add(r["template_id"])
                    deduped.append(r)
            rows = deduped
        else:
            rows = list(interleaved)
# Filter
rows = self._filter_rows(rows, filters)[:max_results]
# Return rows
return rows
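

# Minimal usage sketch (the document text below is illustrative):
if __name__ == "__main__":
    retriever = InstructionTemplateRetriever()
    results = retriever.search(
        "The mitochondria is the powerhouse of the cell. ...",
        max_results=5,
    )
    for r in results:
        print(r["coverage_section"], round(r["score"], 3), r["template_id"])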