# instruction_template_retrieval_embedding / instruction_template_retriever.py
import itertools
import json
import pickle
from random import Random
from datasets import load_dataset
import faiss
import pandas as pd
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer


class GaussianCoveragePooling(torch.nn.Module):
def __init__(self, coverage_chunks, sigma, alpha):
"""
Custom pooling layer that computes weighted mean pooling using Gaussian-based weights.
Args:
coverage_chunks (int): Number of weighted pooling operations (N).
sigma (float): Standard deviation for Gaussian weighting.
alpha (float): Weighting factor for merging with standard mean pooling.
"""
super().__init__()
self.coverage_chunks = coverage_chunks
self.sigma = sigma # Controls width of Gaussians
self.alpha = alpha # Blends standard mean with weighted mean
def forward(self, features, chunk_indicators=None):
"""
Computes weighted mean pooling using Gaussian-based weights.
Args:
            self (GaussianCoveragePooling): The pooling module.
            features (dict): The token embeddings and attention mask.
            chunk_indicators (torch.Tensor[batch_size]): Indices selecting a single
                chunk embedding to return per example. Mainly useful for training;
                leave as None at inference to return embeddings for all chunks.
"""
# Get token embeddings and attention mask
token_embeddings = features[
"token_embeddings"
] # (batch_size, seq_len, hidden_dim)
attention_mask = (
features["attention_mask"].float().unsqueeze(-1)
) # (batch_size, seq_len, 1)
# Get shapes and devices
batch_size, seq_len, hidden_dim = token_embeddings.shape
device = token_embeddings.device
# Compute actual sequence lengths (ignoring padding)
# (batch_size, 1)
seq_lengths = attention_mask.squeeze(-1).sum(dim=1, keepdim=True)
max_seq_length = int(torch.max(seq_lengths).item())
# Standard mean pooling
sum_embeddings = torch.sum(token_embeddings * attention_mask, dim=1)
sum_mask = torch.sum(attention_mask, dim=1).clamp(min=1e-9)
standard_mean = sum_embeddings / sum_mask # (batch_size, hidden_dim)
# Compute chunk centers dynamically based on sequence length
chunk_positions = torch.linspace(0, 1, self.coverage_chunks + 2, device=device)[
1:-1
] # Excludes 0 and 1
chunk_centers = chunk_positions * seq_lengths # (batch_size, N)
# Token positions per sequence (batch_size, seq_len)
token_positions = (
torch.arange(seq_len, device=device).float().unsqueeze(0)
) # (1, seq_len)
# Compute Gaussian weights (batch_size, N, seq_len)
seq_lengths = seq_lengths.view(seq_lengths.shape[0], 1, 1).repeat(
1, self.coverage_chunks, max_seq_length
)
gaussians = torch.exp(
-0.5
* (
(token_positions.unsqueeze(1) - chunk_centers.unsqueeze(2))
/ (self.sigma * seq_lengths)
)
** 2
)
# Mask out padding and normalize Gaussian weights per sequence
# (batch_size, N, seq_len)
gaussians = gaussians * attention_mask.squeeze(-1).unsqueeze(1)
        # Normalize the Gaussian weights so they sum to 1 over each sequence
gaussians /= gaussians.sum(dim=2, keepdim=True).clamp(min=1e-9)
# Compute weighted mean for each chunk (batch_size, N, hidden_dim)
weighted_means = torch.einsum(
"bns,bsh->bnh", gaussians.to(token_embeddings.dtype), token_embeddings
)
# Blend with standard mean pooling
# (batch_size, N, hidden_dim)
combined_embeddings = (1 - self.alpha) * standard_mean.unsqueeze(
1
) + self.alpha * weighted_means
# Add an embedding for the entire document at index 0
# (batch_size, N+1, hidden_dim)
combined_embeddings = torch.cat(
[torch.zeros_like(combined_embeddings[:, :1]), combined_embeddings], 1
)
combined_embeddings[:, 0:1, :] = standard_mean.unsqueeze(1)
# Select the indicator if provided
if chunk_indicators is not None:
combined_embeddings = combined_embeddings[
torch.arange(combined_embeddings.size(0)), chunk_indicators
]
# Normalize all the embeddings
combined_embeddings = torch.nn.functional.normalize(
combined_embeddings, p=2, dim=-1
)
# Flatten final embeddings (batch_size, hidden_dim * (N+1))
if chunk_indicators is None:
sentence_embedding = combined_embeddings.reshape(
batch_size, hidden_dim * (self.coverage_chunks + 1)
)
else:
sentence_embedding = combined_embeddings
        # Return the final flattened sentence embedding
features["sentence_embedding"] = sentence_embedding
return features
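

# Illustrative sketch (not part of the original module): a tiny, self-contained
# demo of the Gaussian chunk weighting used by GaussianCoveragePooling above.
# The sequence length, chunk count, and sigma values below are arbitrary toy values.
def _demo_gaussian_chunk_weights(seq_len=8, coverage_chunks=2, sigma=0.05):
    # Token positions 0..seq_len-1 and evenly spaced chunk centers in (0, 1),
    # scaled to token positions, mirroring GaussianCoveragePooling.forward()
    positions = torch.arange(seq_len).float()  # (seq_len,)
    centers = torch.linspace(0, 1, coverage_chunks + 2)[1:-1] * seq_len  # (N,)
    weights = torch.exp(
        -0.5
        * ((positions.unsqueeze(0) - centers.unsqueeze(1)) / (sigma * seq_len)) ** 2
    )  # (N, seq_len), one Gaussian bump per chunk
    # Normalize so each chunk's weights sum to 1, as in forward()
    return weights / weights.sum(dim=1, keepdim=True)
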
def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
"""
    Wraps the model with a custom pooling layer that computes weighted mean
    pooling using Gaussian-based weights, and returns the wrapped model.
Args:
m (SentenceTransformer): The model to add pooling layer to.
coverage_chunks (int): Number of weighted pooling operations (N).
sigma (float): Standard deviation for Gaussian weighting.
alpha (float): Weighting factor for merging with standard mean pooling.
"""
if isinstance(m[1], GaussianCoveragePooling):
m = unuse_gaussian_coverage_pooling(m)
word_embedding_model = m[0]
custom_pooling = GaussianCoveragePooling(
coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
)
old_pooling = m[1]
new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
new_m.old_pooling = {"old_pooling": old_pooling}
return new_m


def unuse_gaussian_coverage_pooling(m):
"""
Removes the custom pooling layer.
Args:
m (SentenceTransformer): The model to remove the pooling layer from.
"""
if isinstance(m[1], GaussianCoveragePooling):
new_m = m.__class__(modules=[m[0], m.old_pooling["old_pooling"]])
return new_m
else:
return m
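

# Illustrative sketch (not part of the original module): wrapping a
# SentenceTransformer with the coverage pooling above. The model name
# "sentence-transformers/all-MiniLM-L6-v2" is only a placeholder; any
# SentenceTransformer whose module at index 1 is a standard pooling layer
# should behave the same way.
def _demo_use_and_unuse_coverage_pooling(coverage_chunks=10):
    m = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    m = use_gaussian_coverage_pooling(
        m, coverage_chunks=coverage_chunks, sigma=0.05, alpha=1.0
    )
    # The wrapped model emits one embedding per chunk plus one for the whole
    # document, flattened to hidden_dim * (coverage_chunks + 1) per input
    emb = m.encode(["Some long document ..."], normalize_embeddings=False)
    hidden_dim = m[0].auto_model.config.hidden_size
    assert emb.shape[-1] == hidden_dim * (coverage_chunks + 1)
    # Restore the original pooling layer when coverage embeddings are not needed
    return unuse_gaussian_coverage_pooling(m)
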
class InstructionTemplateRetriever:
FINETEMPLATES_REVISION = "4c8f22e0d6521a634ed12e3ebd4c438cf8f0c7fa"
RETRIEVAL_EMBEDDING_NAME = (
"fineinstructions/instruction_template_retrieval_embedding"
)
RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"
def __init__(
self,
coverage_chunks=10,
sigma=0.05,
alpha=1.0,
nprobe=150,
):
"""
        Retrieves relevant instruction templates for a document by computing
        Gaussian-weighted embeddings that cover different parts of the document.
Args:
coverage_chunks (int): The number of equally sized chunks/sections
to get coverage over the entire document.
            sigma (float): Standard deviation for Gaussian weighting; this
                controls how "wide" or "focused" each chunk is.
            alpha (float): A weighting factor that balances the representation
                of a single chunk against the representation of the entire
                document.
nprobe (int): The number of probes to use when searching the FAISS
index (larger is more accurate, but slower).
"""
self.d = load_dataset(
"fineinstructions/finetemplates",
revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
split="full",
)
self.m = SentenceTransformer(
InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME,
revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION,
device="cpu",
)
self.m = use_gaussian_coverage_pooling(
self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
)
self.index = faiss.read_index(
hf_hub_download(
"fineinstructions/finetemplates",
"faiss_index/finetemplates.index",
revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
repo_type="dataset",
),
faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
)
self.index.nprobe = nprobe
if torch.cuda.is_available():
self.m = self.m.to("cuda")
elif torch.backends.mps.is_available():
self.m = self.m.to("mps")
with open(
hf_hub_download(
"fineinstructions/finetemplates",
"faiss_index/reweighting_stats.pkl",
revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
repo_type="dataset",
),
"rb",
) as reweighting_stats_fp:
reweighting_stats = pickle.load(reweighting_stats_fp)
self.resampling_weights = reweighting_stats["resampling_weights"]
self.template_variable_count_mapping = reweighting_stats[
"template_variable_count_mapping"
]
def _filter_rows(self, rows, filter_string):
if not rows:
return []
df = pd.DataFrame(rows)
try:
filtered_df = df.query(filter_string)
return filtered_df.to_dict(orient="records")
        except Exception:
            # Fall back to the unfiltered rows if the filter string is invalid
            return rows
def search(
self,
document,
filters="",
search_k=20000,
max_results=250,
deduplicate=True,
reweight=False,
reweight_epsilon=0.05,
):
"""
        Retrieves relevant instruction templates for a given document.
Args:
document (str): The document to retrieve relevant instruction templates for.
            filters (str): A query string in the format of pandas.DataFrame.query(),
                evaluated over the returned result columns (e.g. "score > 0.5").
search_k (int): The number of search results to pull when retrieving from FAISS.
max_results (int): The max number of results to return.
deduplicate (bool): Deduplicate results between coverage sections.
reweight (bool): Whether to reweight the results based on a more realistic length distribution.
            reweight_epsilon (float): Score tolerance used when reweighting (larger
                values allow less accurate retrieval results but better reweighting).
"""
        def _reweight(inp, k=None):
            # Resample (with replacement) the results whose scores fall within
            # reweight_epsilon of the top score, weighting each template by a
            # precomputed resampling weight keyed on its variable count
            if reweight:
inp0, inp = itertools.tee(inp)
first_row = next(inp0)
r = Random(first_row[1].item())
epsilon = reweight_epsilon
bucket = first_row[1]
items = []
weights = []
for i, s in inp:
if abs(bucket - s.item()) <= epsilon:
items.append((i, s))
weights.append(
self.resampling_weights[
self.template_variable_count_mapping[i.item()]
]
)
else:
break
return r.choices(
items, weights=weights, k=(len(items) if k is None else k)
)
else:
return inp
        # Encode the document into (coverage_chunks + 1) embeddings (one per chunk
        # plus one for the entire document) and search the FAISS index
vecs = self.m.encode([document], normalize_embeddings=False).reshape(
-1, self.m[0].auto_model.config.hidden_size
)
scores_batch, indices_batch = self.index.search(np.vstack(vecs), k=search_k)
# Pull in FineTemplates rows into memory
to_select = [i.item() for i in itertools.chain.from_iterable(indices_batch)]
d_in_mem = {
i: row for i, row in zip(to_select, self.d.select(to_select).to_list())
}
# Group by coverage chunk
true_coverage_chunks = self.m[1].coverage_chunks + 1
        scores_per_input = [
            scores_batch[i : i + true_coverage_chunks]
            for i in range(0, len(scores_batch), true_coverage_chunks)
        ]
        indices_per_input = [
            indices_batch[i : i + true_coverage_chunks]
            for i in range(0, len(indices_batch), true_coverage_chunks)
        ]
        # Keep the results for the first input, since the batch contains a single document
scores_per_input, indices_per_input = scores_per_input[0], indices_per_input[0]
# Create result rows
rows = [
[
{
"coverage_section": f"{chunk_idx}/{self.m[1].coverage_chunks}"
if chunk_idx > 0
else "Entire Document",
"score": s.item(),
**d_in_mem[i.item()],
}
for i, s in _reweight(zip(indices, scores), k=None)
]
for chunk_idx, (indices, scores) in enumerate(
zip(indices_per_input, scores_per_input)
)
]
        # Deduplicate, interleaving results round-robin across coverage sections
        # and keeping only the first occurrence of each template_id
        if deduplicate:
            seen = set()
            deduplicated_rows = []
            for r in itertools.chain.from_iterable(zip(*rows)):
                if r["template_id"] not in seen:
                    seen.add(r["template_id"])
                    deduplicated_rows.append(r)
            rows = deduplicated_rows
else:
rows = list(itertools.chain.from_iterable(zip(*rows)))
# Filter
rows = self._filter_rows(rows, filters)[:max_results]
# Return rows
return rows
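

# Illustrative usage sketch (not part of the original module). Instantiating the
# retriever downloads the finetemplates dataset, the retrieval embedding model,
# and the FAISS index, so this requires network access; the document text below
# is a made-up example, and the printed fields are columns of the returned rows.
if __name__ == "__main__":
    retriever = InstructionTemplateRetriever(coverage_chunks=10, sigma=0.05, alpha=1.0)
    results = retriever.search(
        document="A long web document about training neural networks ...",
        filters="",  # optionally a pandas.DataFrame.query() string, e.g. "score > 0.5"
        max_results=10,
        deduplicate=True,
    )
    for r in results:
        print(r["coverage_section"], round(r["score"], 3), r["template_id"])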