import itertools
import json
import pickle
from random import Random

from datasets import load_dataset
import faiss
import pandas as pd
import numpy as np
import torch

from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer


class GaussianCoveragePooling(torch.nn.Module):
    def __init__(self, coverage_chunks, sigma, alpha):
        """
        Custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

        Args:
            coverage_chunks (int): Number of weighted pooling operations (N).
            sigma (float): Standard deviation for Gaussian weighting.
            alpha (float): Weighting factor for merging with standard mean pooling.
        """
        super().__init__()
        self.coverage_chunks = coverage_chunks
        self.sigma = sigma
        self.alpha = alpha

    def forward(self, features, chunk_indicators=None):
        """
        Computes weighted mean pooling using Gaussian-based weights.

        Args:
            features (dict): The token embeddings and attention mask.
            chunk_indicators (tensor[bz, 1]): Index indicators to return a specific
                chunk; leave as None to return embeddings for all chunks. Mainly
                useful for training, not inference.
        """
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"].float().unsqueeze(-1)

        batch_size, seq_len, hidden_dim = token_embeddings.shape
        device = token_embeddings.device

        # Effective (unpadded) length of each sequence
        seq_lengths = attention_mask.squeeze(-1).sum(dim=1, keepdim=True)
        max_seq_length = int(torch.max(seq_lengths).item())

        # Standard masked mean pooling over the entire document
        sum_embeddings = torch.sum(token_embeddings * attention_mask, dim=1)
        sum_mask = torch.sum(attention_mask, dim=1).clamp(min=1e-9)
        standard_mean = sum_embeddings / sum_mask

        # Evenly spaced chunk centers in (0, 1), scaled by each sequence's length
        chunk_positions = torch.linspace(0, 1, self.coverage_chunks + 2, device=device)[
            1:-1
        ]
        chunk_centers = chunk_positions * seq_lengths

        token_positions = torch.arange(seq_len, device=device).float().unsqueeze(0)

        # Gaussian weights per chunk, with width proportional to sequence length
        seq_lengths = seq_lengths.view(seq_lengths.shape[0], 1, 1).repeat(
            1, self.coverage_chunks, max_seq_length
        )
        gaussians = torch.exp(
            -0.5
            * (
                (token_positions.unsqueeze(1) - chunk_centers.unsqueeze(2))
                / (self.sigma * seq_lengths)
            )
            ** 2
        )

        # Zero out padding tokens and normalize the weights within each chunk
        gaussians = gaussians * attention_mask.squeeze(-1).unsqueeze(1)
        gaussians /= gaussians.sum(dim=2, keepdim=True).clamp(min=1e-9)

        # Weighted mean per chunk: (batch_size, coverage_chunks, hidden_dim)
        weighted_means = torch.einsum(
            "bns,bsh->bnh", gaussians.to(token_embeddings.dtype), token_embeddings
        )

        # Blend each chunk embedding with the whole-document mean
        combined_embeddings = (1 - self.alpha) * standard_mean.unsqueeze(
            1
        ) + self.alpha * weighted_means

        # Prepend the whole-document embedding as chunk 0
        combined_embeddings = torch.cat(
            [torch.zeros_like(combined_embeddings[:, :1]), combined_embeddings], 1
        )
        combined_embeddings[:, 0:1, :] = standard_mean.unsqueeze(1)

        # During training, select only the requested chunk for each example
        if chunk_indicators is not None:
            combined_embeddings = combined_embeddings[
                torch.arange(combined_embeddings.size(0)), chunk_indicators
            ]

        # L2-normalize each embedding
        combined_embeddings = torch.nn.functional.normalize(
            combined_embeddings, p=2, dim=-1
        )

        # At inference, flatten the (coverage_chunks + 1) embeddings into one vector
        if chunk_indicators is None:
            sentence_embedding = combined_embeddings.reshape(
                batch_size, hidden_dim * (self.coverage_chunks + 1)
            )
        else:
            sentence_embedding = combined_embeddings

        features["sentence_embedding"] = sentence_embedding
        return features
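
# A minimal shape-check sketch for GaussianCoveragePooling (illustrative addition,
# not used by the module): with coverage_chunks=N and chunk_indicators=None, the
# pooled output is the whole-document embedding followed by N chunk embeddings,
# flattened to (N + 1) * hidden_dim values per input.
#
#     pool = GaussianCoveragePooling(coverage_chunks=3, sigma=0.05, alpha=1.0)
#     feats = {
#         "token_embeddings": torch.randn(2, 16, 8),
#         "attention_mask": torch.ones(2, 16),
#     }
#     out = pool(feats)["sentence_embedding"]
#     assert out.shape == (2, (3 + 1) * 8)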


def use_gaussian_coverage_pooling(m, coverage_chunks=10, sigma=0.05, alpha=1.0):
    """
    Adds a custom pooling layer that computes weighted mean pooling using Gaussian-based weights.

    Args:
        m (SentenceTransformer): The model to add the pooling layer to.
        coverage_chunks (int): Number of weighted pooling operations (N).
        sigma (float): Standard deviation for Gaussian weighting.
        alpha (float): Weighting factor for merging with standard mean pooling.
    """
    # If the model is already wrapped, restore its original pooling layer first
    if isinstance(m[1], GaussianCoveragePooling):
        m = unuse_gaussian_coverage_pooling(m)
    word_embedding_model = m[0]
    custom_pooling = GaussianCoveragePooling(
        coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
    )
    old_pooling = m[1]
    new_m = m.__class__(modules=[word_embedding_model, custom_pooling])
    # Keep a reference to the original pooling layer so it can be restored later
    new_m.old_pooling = {"old_pooling": old_pooling}
    return new_m
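
# Usage sketch (illustrative; the base model name below is just an example
# bi-encoder, not something this module depends on):
#
#     base = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#     m = use_gaussian_coverage_pooling(base, coverage_chunks=10, sigma=0.05, alpha=1.0)
#     emb = m.encode(["a long document ..."], normalize_embeddings=False)
#     # emb has (10 + 1) * hidden_dim values per input: the whole-document embedding
#     # followed by one embedding per coverage chunk.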


def unuse_gaussian_coverage_pooling(m):
    """
    Removes the custom pooling layer.

    Args:
        m (SentenceTransformer): The model to remove the pooling layer from.
    """
    if isinstance(m[1], GaussianCoveragePooling):
        new_m = m.__class__(modules=[m[0], m.old_pooling["old_pooling"]])
        return new_m
    else:
        return m
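
# The wrap is reversible because use_gaussian_coverage_pooling() stores the original
# pooling layer on the wrapped model (sketch):
#
#     restored = unuse_gaussian_coverage_pooling(m)  # back to the standard pooling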


class InstructionTemplateRetriever:
    FINETEMPLATES_REVISION = "4c8f22e0d6521a634ed12e3ebd4c438cf8f0c7fa"
    RETRIEVAL_EMBEDDING_NAME = (
        "fineinstructions/instruction_template_retrieval_embedding"
    )
    RETRIEVAL_EMBEDDING_REVISION = "db4efbde126216250ffa5a356663fc7da3bf7856"

    def __init__(
        self,
        coverage_chunks=10,
        sigma=0.05,
        alpha=1.0,
        nprobe=150,
    ):
        """
        Computes embeddings that cover a document to find relevant instruction
        templates, using Gaussian-weighted embeddings that cover different parts
        of the document.

        Args:
            coverage_chunks (int): The number of equally sized chunks/sections
                to get coverage over the entire document.
            sigma (float): Standard deviation for Gaussian weighting; this
                essentially controls how "wide" / "focused" each chunk is.
            alpha (float): A weighting factor that balances the representation
                of a single chunk against the representation of the entire
                document.
            nprobe (int): The number of probes to use when searching the FAISS
                index (larger is more accurate, but slower).
        """
        # FineTemplates dataset holding the instruction templates to retrieve
        self.d = load_dataset(
            "fineinstructions/finetemplates",
            revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
            split="full",
        )
        # Retrieval embedding model, wrapped with Gaussian coverage pooling
        self.m = SentenceTransformer(
            InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_NAME,
            revision=InstructionTemplateRetriever.RETRIEVAL_EMBEDDING_REVISION,
            device="cpu",
        )
        self.m = use_gaussian_coverage_pooling(
            self.m, coverage_chunks=coverage_chunks, sigma=sigma, alpha=alpha
        )
        # Memory-mapped FAISS index over the template embeddings
        self.index = faiss.read_index(
            hf_hub_download(
                "fineinstructions/finetemplates",
                "faiss_index/finetemplates.index",
                revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
                repo_type="dataset",
            ),
            faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
        )
        self.index.nprobe = nprobe
        # Move the encoder to a GPU or Apple Silicon device if available
        if torch.cuda.is_available():
            self.m = self.m.to("cuda")
        elif torch.backends.mps.is_available():
            self.m = self.m.to("mps")

        # Statistics used to optionally reweight results toward a more realistic
        # distribution of templates
        with open(
            hf_hub_download(
                "fineinstructions/finetemplates",
                "faiss_index/reweighting_stats.pkl",
                revision=InstructionTemplateRetriever.FINETEMPLATES_REVISION,
                repo_type="dataset",
            ),
            "rb",
        ) as reweighting_stats_fp:
            reweighting_stats = pickle.load(reweighting_stats_fp)
        self.resampling_weights = reweighting_stats["resampling_weights"]
        self.template_variable_count_mapping = reweighting_stats[
            "template_variable_count_mapping"
        ]

    def _filter_rows(self, rows, filter_string):
        if not rows:
            return []
        df = pd.DataFrame(rows)
        try:
            filtered_df = df.query(filter_string)
            return filtered_df.to_dict(orient="records")
        except Exception:
            # Fall back to the unfiltered rows if the filter string is invalid
            return rows
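
    # Filter strings use pandas.DataFrame.query() syntax over the columns of the
    # result rows built by search() (illustrative values):
    #
    #     self._filter_rows(rows, 'coverage_section == "Entire Document" and score > 0.5')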

    def search(
        self,
        document,
        filters="",
        search_k=20000,
        max_results=250,
        deduplicate=True,
        reweight=False,
        reweight_epsilon=0.05,
    ):
        """
        Given a document, retrieves relevant instruction templates.

        Args:
            document (str): The document to retrieve relevant instruction templates for.
            filters (str): A query string in the format of pandas.DataFrame.query().
            search_k (int): The number of search results to pull when retrieving from FAISS.
            max_results (int): The max number of results to return.
            deduplicate (bool): Deduplicate results between coverage sections.
            reweight (bool): Whether to reweight the results based on a more realistic
                length distribution.
            reweight_epsilon (float): How tolerant to be when reweighting (larger values
                give less accurate retrieval but better reweighting).
        """

        def _reweight(inp, k=None):
            # Resample results whose scores are within `reweight_epsilon` of the top
            # result, weighted by the resampling statistics for each template's
            # variable count; otherwise pass the results through unchanged.
            if reweight:
                inp0, inp = itertools.tee(inp)
                first_row = next(inp0)
                # Seed deterministically from the top score
                r = Random(first_row[1].item())
                epsilon = reweight_epsilon
                bucket = first_row[1]
                items = []
                weights = []
                for i, s in inp:
                    if abs(bucket - s.item()) <= epsilon:
                        items.append((i, s))
                        weights.append(
                            self.resampling_weights[
                                self.template_variable_count_mapping[i.item()]
                            ]
                        )
                    else:
                        break
                return r.choices(
                    items, weights=weights, k=(len(items) if k is None else k)
                )
            else:
                return inp

        # Encode the document into (coverage_chunks + 1) embeddings: one for the
        # entire document and one per coverage section
        vecs = self.m.encode([document], normalize_embeddings=False).reshape(
            -1, self.m[0].auto_model.config.hidden_size
        )
        scores_batch, indices_batch = self.index.search(np.vstack(vecs), k=search_k)

        # Pull the matching dataset rows into memory
        to_select = [i.item() for i in itertools.chain.from_iterable(indices_batch)]
        d_in_mem = {
            i: row for i, row in zip(to_select, self.d.select(to_select).to_list())
        }

        # Group the FAISS results by input (each input contributes
        # coverage_chunks + 1 query vectors)
        true_coverage_chunks = self.m[1].coverage_chunks + 1
        scores_per_input, indices_per_input = (
            [
                scores_batch[i : i + true_coverage_chunks]
                for i in range(0, len(scores_batch), true_coverage_chunks)
            ],
            [
                indices_batch[i : i + true_coverage_chunks]
                for i in range(0, len(indices_batch), true_coverage_chunks)
            ],
        )

        # Only a single document is encoded, so keep the first group
        scores_per_input, indices_per_input = scores_per_input[0], indices_per_input[0]

        # Build the result rows for each coverage section
        rows = [
            [
                {
                    "coverage_section": f"{chunk_idx}/{self.m[1].coverage_chunks}"
                    if chunk_idx > 0
                    else "Entire Document",
                    "score": s.item(),
                    **d_in_mem[i.item()],
                }
                for i, s in _reweight(zip(indices, scores), k=None)
            ]
            for chunk_idx, (indices, scores) in enumerate(
                zip(indices_per_input, scores_per_input)
            )
        ]

        # Interleave the coverage sections, optionally dropping duplicate templates
        if deduplicate:
            seen = set()
            rows = [
                r
                for r in itertools.chain.from_iterable(zip(*rows))
                # Keep a row only if its template_id has not been seen before
                # (seen.add() returns None, so `or seen` evaluates to the updated set)
                if (len(seen) != len(seen.add(r["template_id"]) or seen))
            ]
        else:
            rows = list(itertools.chain.from_iterable(zip(*rows)))

        # Apply any user-provided filters and truncate to the maximum result count
        rows = self._filter_rows(rows, filters)[:max_results]

        return rows
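
# End-to-end usage sketch (illustrative; the constructor downloads the dataset,
# encoder, FAISS index, and reweighting stats from the Hugging Face Hub):
#
#     retriever = InstructionTemplateRetriever()
#     results = retriever.search(
#         "Some long document to find instruction templates for...",
#         max_results=10,
#     )
#     for row in results:
#         print(row["coverage_section"], row["score"], row["template_id"])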